// ReLU layer: gate the incoming gradient through the ReLU derivative,
// then delegate the weight update and input gradient to the inner
// linear layer.
void ReLULayer::BackPropagate( MatrixBase& outDiff, const MatrixBase& in,
                               const MatrixBase& out, float learningRate,
                               MatrixBase* inDiff ) {
    // outDiff *= ReLU'(out): zero the gradient wherever out <= 0
    outDiff.DiffReLU( out, outDiff );
    m_linear.BackPropagate( outDiff, in, Matrix(), learningRate, inDiff );
} // end of BackPropagate
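// Minimal CPU sketch of the element-wise gating that DiffReLU presumably
// performs (assumption: the actual GPU kernel is equivalent; DiffReLUSketch
// is a hypothetical name, not part of this codebase). Since d/dx ReLU(x)
// is 1 for x > 0 and 0 otherwise, the gradient is zeroed wherever the
// activation was clipped:
static void DiffReLUSketch( float* diff, const float* out, int n ) {
    for ( int i = 0; i < n; ++i ) {
        if ( out[i] <= 0.0f ) {
            diff[i] = 0.0f;  // no gradient flows through a clipped unit
        }
    }
}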
// vFSMN layer: same ReLU/linear backprop as above, plus a momentum-SGD
// update of the memory-block weights and the vFSMN memory filter.
void vFSMNLayer::BackPropagate( MatrixBase& outDiff, const MatrixBase& in,
                                const MatrixBase& out, float learningRate,
                                MatrixBase* inDiff ) {
    // Gate the gradient through the ReLU derivative and update the
    // inner linear layer.
    outDiff.DiffReLU( out, outDiff );
    m_linear.BackPropagate( outDiff, in, Matrix(), learningRate, inDiff );

    // Lazily (re)allocate the momentum buffer to match the weight
    // shape, zero-initialized.
    if ( m_w_momentum.Rows() != m_weight.Rows() ||
         m_w_momentum.Columns() != m_weight.Columns() ) {
        m_w_momentum.Reshape( m_weight.Rows(), m_weight.Columns(), kSetZero );
    }

    // Negative learning rate averaged over the mini-batch rows, so the
    // Sgemm below already accumulates a descent step.
    float avg_lr = -learningRate / (float)outDiff.Rows();

    // m_w_momentum = m_momentum * m_w_momentum + avg_lr * m_memory^T * outDiff
    m_w_momentum.Sgemm( m_momentum, avg_lr, m_memory, CUBLAS_OP_T,
                        outDiff, CUBLAS_OP_N );

    // L2 regularization: fold -lr * decay * weight into the update.
    if ( m_weight_decay != 0.0f ) {
        m_w_momentum.Add( 1.0f, -learningRate * m_weight_decay, m_weight );
    }

    // Gradient w.r.t. the memory-block output:
    // m_memory_diff = outDiff * m_weight^T (uses the pre-update weights).
    m_memory_diff.Reshape( m_memory.Rows(), m_memory.Columns() );
    m_memory_diff.Sgemm( 0.0f, 1.0f, outDiff, CUBLAS_OP_N,
                         m_weight, CUBLAS_OP_T );

    // Propagate the gradient back through the vFSMN memory operation.
    if ( NULL != inDiff ) {
        inDiff->ComputeVfsmnHiddenDiff( m_memory_diff, m_filter, m_position );
    }

    // Optional update clipping: either clamp each element of the update
    // to a range scaled by the effective learning rate ...
    if ( m_norm_type == "Clip" ) {
        float range = m_norm_param * (-avg_lr);
        m_w_momentum.Clip( -range, range );
    }
    // ... or rescale the whole update when its learning-rate-scaled
    // L2 norm exceeds the threshold.
    if ( m_norm_type == "L2Norm" ) {
        float w_norm = m_w_momentum.L2Norm() * (-avg_lr);
        if ( w_norm > m_norm_param ) {
            m_w_momentum.Scale( 1.0f / w_norm );
        }
    }

    // Update the memory filter coefficients, then apply the
    // momentum-smoothed update: m_weight += m_w_momentum.
    m_filter.UpdateVfsmnFilter( m_memory_diff, in, m_position, avg_lr );
    m_weight.Add( 1.0f, 1.0f, m_w_momentum );
} // end of BackPropagate
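// Illustrative scalar form of the weight update above. Assumptions (not
// confirmed by this file): Sgemm computes this = beta*this + alpha*op(A)*op(B)
// and Add computes this = a*this + b*M, matching the call sites in
// BackPropagate; MomentumSgdStepSketch is a hypothetical helper name.
// Per weight w[i] with mini-batch gradient g[i] = (memory^T * outDiff)[i]:
//   v[i]  = momentum * v[i] - (lr / batch) * g[i] - lr * decay * w[i];
//   w[i] += v[i];
static void MomentumSgdStepSketch( float* w, float* v, const float* g,
                                   int n, float lr, float momentum,
                                   float decay, int batch ) {
    for ( int i = 0; i < n; ++i ) {
        v[i] = momentum * v[i] - ( lr / (float)batch ) * g[i]
               - lr * decay * w[i];
        w[i] += v[i];  // momentum-smoothed, weight-decayed descent step
    }
}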