Exemplo n.º 1
0
/**
 * Backward pass of a ReLU layer built on top of a linear transform.
 *
 * @param outDiff       error w.r.t. this layer's output; overwritten in place
 *                      with the ReLU-gated error before being passed on.
 * @param in            this layer's forward input.
 * @param out           this layer's forward output; the ReLU derivative is
 *                      evaluated from it.
 * @param learningRate  step size forwarded to the linear sub-layer.
 * @param inDiff        optional destination for the error w.r.t. `in`;
 *                      forwarded unchanged to the linear sub-layer.
 */
void ReLULayer::BackPropagate( MatrixBase& outDiff, 
                               const MatrixBase& in,
                               const MatrixBase& out,
                               float learningRate,
                               MatrixBase* inDiff ) {
    // Apply the ReLU derivative (computed from the forward output) to the
    // incoming error, writing the result back into outDiff.
    MatrixBase& gatedDiff = outDiff;
    gatedDiff.DiffReLU( out, gatedDiff );

    // The linear sub-layer receives an empty matrix in its "output" slot.
    // NOTE(review): presumably its backward pass does not read that argument
    // -- confirm against the linear layer's implementation.
    Matrix noOutput;
    m_linear.BackPropagate( gatedDiff, in, noOutput, learningRate, inDiff );
}    // end of BackPropagate
Exemplo n.º 2
0
/**
 * Backward pass of the vectorized FSMN layer.
 *
 * Order of operations (the order matters):
 *  1. Gate outDiff in place by the ReLU derivative evaluated on `out`.
 *  2. Back-propagate through the linear sub-layer m_linear, which receives
 *     learningRate and inDiff.
 *  3. Accumulate the memory-weight update into m_w_momentum (momentum,
 *     averaged gradient, optional weight decay).
 *  4. Compute m_memory_diff from outDiff and m_weight *before* m_weight is
 *     updated, and use it (with the not-yet-updated m_filter) to add the
 *     memory-path contribution to inDiff.
 *  5. Optionally bound m_w_momentum ("Clip" or "L2Norm").
 *  6. Update m_filter, then apply m_w_momentum to m_weight.
 *
 * @param outDiff       error w.r.t. this layer's output; overwritten in place.
 * @param in            this layer's forward input.
 * @param out           this layer's forward output.
 * @param learningRate  step size.
 * @param inDiff        optional error w.r.t. `in`; NULL to skip computing it.
 */
void vFSMNLayer::BackPropagate( MatrixBase& outDiff, 
                                const MatrixBase& in,
                                const MatrixBase& out,
                                float learningRate,
                                MatrixBase* inDiff ) {
    outDiff.DiffReLU( out, outDiff );
    m_linear.BackPropagate( outDiff, in, Matrix(), learningRate, inDiff );

    // (Re)allocate the momentum buffer, zero-filled, whenever its shape does
    // not match the memory weights (e.g. on the first call).
    if ( m_w_momentum.Rows() != m_weight.Rows() || 
         m_w_momentum.Columns() != m_weight.Columns() ) {
        m_w_momentum.Reshape( m_weight.Rows(), m_weight.Columns(), kSetZero );
    }    

    // Negated learning rate divided by the number of rows of outDiff
    // (presumably the mini-batch size -- confirm against callers).
    float avg_lr = -learningRate / static_cast<float>( outDiff.Rows() );

    // Assuming Sgemm( beta, alpha, A, opA, B, opB ) computes
    // C = beta * C + alpha * op(A) * op(B) (consistent with the beta = 0 call
    // below), this is: momentum = m_momentum * momentum + avg_lr * memory^T * outDiff.
    m_w_momentum.Sgemm( m_momentum,
                        avg_lr,
                        m_memory,
                        CUBLAS_OP_T,
                        outDiff,
                        CUBLAS_OP_N );

    // Optional weight decay folded into the same update buffer.
    if ( m_weight_decay != 0.0f ) {
        m_w_momentum.Add( 1.0f, -learningRate * m_weight_decay, m_weight );
    }

    // Error w.r.t. the memory block. It must use m_weight before the update
    // at the end of this function.
    m_memory_diff.Reshape( m_memory.Rows(), m_memory.Columns() );
    m_memory_diff.Sgemm( 0.0f, 1.0f, outDiff, CUBLAS_OP_N, m_weight, CUBLAS_OP_T );

    // Propagate the memory-path error back to the input through the FSMN
    // filter (still the pre-update filter at this point).
    if ( NULL != inDiff ) {
        inDiff->ComputeVfsmnHiddenDiff( m_memory_diff, m_filter, m_position );
    }

    // Element-wise clipping, threshold scaled by the averaged learning rate.
    if ( m_norm_type == "Clip" ) {
        float range = m_norm_param * (-avg_lr);
        m_w_momentum.Clip( -range, range );
    }

    // Norm-based rescaling of the update.
    if ( m_norm_type == "L2Norm" ) {
        // NOTE(review): the "Clip" branch scales the threshold by (-avg_lr),
        // i.e. it bounds momentum / (-avg_lr). A consistent norm here would be
        // L2Norm() / (-avg_lr) rather than the product below -- confirm the
        // intended units of m_norm_param.
        float w_norm = m_w_momentum.L2Norm() * (-avg_lr);
        if ( w_norm > m_norm_param ) {
            // Shrink so the measured norm equals m_norm_param. The former
            // factor 1 / w_norm forced it to 1 regardless of the threshold,
            // so thresholds below 1 were never enforced.
            m_w_momentum.Scale( m_norm_param / w_norm );
        }
    }

    m_filter.UpdateVfsmnFilter( m_memory_diff, in, m_position, avg_lr );
    m_weight.Add( 1.0f, 1.0f, m_w_momentum );
}