// Backward pass of the projection (lookup-table) layer.
//
// outDiff holds the gradient w.r.t. this layer's output, packed as n_order
// slices of n_project columns each; `in` holds the row indices used by the
// forward lookup (one column per order position). Each gradient slice is
// scatter-added back into the rows of m_projection selected by the matching
// index column. `out` is unused here.
//
// NOTE(review): momentum is engaged only when BOTH m_momentum and
// m_weight_decay are non-zero (&&) — confirm this is intended and not ||.
void ProjectionLayer::BackPropagate( MatrixBase& outDiff, const MatrixBase& in,
                                     const MatrixBase& out, float learningRate,
                                     MatrixBase* inDiff ) {
    int n_example = outDiff.Rows();
    int n_project = m_projection.Columns();
    // Number of order positions packed side by side in outDiff.
    int n_order = outDiff.Columns() / n_project;
    ASSERT( n_order == in.Columns() );

    bool useMomentum = m_momentum != 0.0f && m_weight_decay != 0.0f;

    // Gradient normalization, variant 1: hard element-wise clipping.
    if ( m_norm_type == "Clip" ) {
        outDiff.Clip( -m_norm_param, m_norm_param );
    }
    // Variant 2: rescale when the L2 norm exceeds the threshold.
    // NOTE(review): this rescales the gradient to norm 1.0, not to
    // m_norm_param — confirm that is the intended normalization.
    if ( m_norm_type == "L2Norm" ) {
        float norm = outDiff.L2Norm();
        if ( norm > m_norm_param ) {
            outDiff.Scale( 1.0f / norm );
        }
    }

    // Lazily (re)allocate the momentum buffer to match the projection table.
    if ( useMomentum && (m_gradient.Rows() != m_projection.Rows() ||
                         m_gradient.Columns() != m_projection.Columns() ) ) {
        m_gradient.Reshape( m_projection.Rows(), m_projection.Columns(), kSetZero );
    }

    for ( int i = 0; i < n_order; i++ ) {
        // Gradient slice and index column belonging to order position i.
        SubMatrix gradient( outDiff, MatrixRange(0, n_example, i * n_project, (i + 1) * n_project ) );
        SubMatrix index( in, MatrixRange(0, n_example, i, i + 1) );
        if ( useMomentum ) {
            // m_gradient = momentum * m_gradient - lr * weight_decay * projection,
            // then scatter-add the batch-averaged slice gradient.
            // NOTE(review): this decay/scale step runs once per order position
            // (i.e. n_order times per call) — confirm it should not be hoisted
            // out of the loop.
            m_gradient.Add( m_momentum, -learningRate * m_weight_decay, m_projection );
            m_gradient.AddRows( gradient, index, -learningRate / n_example );
        } else {
            // Plain SGD: scatter-add directly into the projection table.
            m_projection.AddRows( gradient, index, -learningRate / n_example );
        }
    }
    if ( useMomentum ) {
        // Apply the accumulated momentum update to the projection table.
        m_projection.Add( 1.0f, 1.0f, m_gradient );
    }
} // end of BackPropagate
void vFSMNLayer::Compute( const MatrixBase& feature, MatrixBase* output ) { m_linear.Compute( feature, output ); m_memory.Reshape( feature.Rows(), feature.Columns() ); m_memory.VfsmnMemory( feature, m_filter, m_position ); output->Sgemm( 1.0f, 1.0f, m_memory, CUBLAS_OP_N, m_weight, CUBLAS_OP_N ); output->ReLU( *output ); }
void sFSMNLayer::Compute( const MatrixBase& feature, MatrixBase* output ) { m_linear.Compute( feature, output ); m_memory.Reshape( feature.Rows(), feature.Columns() ); m_memory.Strmm( 1.0f, m_block_diagonal, CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER, feature, CUBLAS_OP_N ); // @xmb20160226 // m_memory.Sgemm( 0.0f, 1.0f, m_block_diagonal, CUBLAS_OP_N, feature, CUBLAS_OP_N ); output->Sgemm( 1.0f, 1.0f, m_memory, CUBLAS_OP_N, m_weight, CUBLAS_OP_N ); output->ReLU( *output ); }
void ProjectionLayer::Compute( const MatrixBase& feature, MatrixBase* output ) { int n_order = feature.Columns(); int n_example = feature.Rows(); int n_project = m_projection.Columns(); assert( feature.Rows() == output->Rows() ); assert( output->Columns() == n_order * m_projection.Columns() ); for ( int i = 0; i < n_order; i++ ) { SubMatrix gram( *output, MatrixRange(0, n_example, i * n_project, (i + 1) * n_project ) ); SubMatrix index( feature, MatrixRange(0, n_example, i, i + 1) ); gram.GetRows( m_projection, index ); } } // end of ForwardPropagate
// Backward pass of the scalar FSMN layer. Updates the linear sub-layer, the
// memory-projection weight m_weight and the memory filter, and accumulates
// the input gradient into *inDiff when it is non-NULL.
void sFSMNLayer::BackPropagate( MatrixBase& outDiff, const MatrixBase& in,
                                const MatrixBase& out, float learningRate,
                                MatrixBase* inDiff ) {
    // Gradient through the ReLU applied in Compute() (in place on outDiff).
    outDiff.DiffReLU( out, outDiff );
    // Update the linear sub-layer; this also seeds *inDiff with the linear
    // path's input gradient — the memory path is accumulated on top below.
    m_linear.BackPropagate( outDiff, in, Matrix(), learningRate, inDiff );

    // Lazily (re)allocate the momentum buffer for the memory weight.
    if ( m_w_momentum.Rows() != m_weight.Rows() || m_w_momentum.Columns() != m_weight.Columns() ) {
        m_w_momentum.Reshape( m_weight.Rows(), m_weight.Columns(), kSetZero );
    }

    // Negative learning rate averaged over the mini-batch rows.
    float avg_lr = -learningRate / (float)outDiff.Rows();
    // m_w_momentum = momentum * m_w_momentum + avg_lr * memory^T * outDiff.
    m_w_momentum.Sgemm( m_momentum, avg_lr, m_memory, CUBLAS_OP_T, outDiff, CUBLAS_OP_N );
    if ( m_weight_decay != 0.0f ) {
        // L2 regularization folded into the momentum buffer.
        m_w_momentum.Add( 1.0f, -learningRate * m_weight_decay, m_weight );
    }

    // Gradient w.r.t. the memory activations: outDiff * weight^T.
    m_memory_diff.Reshape( m_memory.Rows(), m_memory.Columns() );
    m_memory_diff.Sgemm( 0.0f, 1.0f, outDiff, CUBLAS_OP_N, m_weight, CUBLAS_OP_T );
    if ( NULL != inDiff ) {
        // Memory-path input gradient, accumulated (beta = 1.0) on top of the
        // linear part already stored in *inDiff.
        inDiff->Sgemm( 1.0f, 1.0f, m_block_diagonal, CUBLAS_OP_T, m_memory_diff, CUBLAS_OP_N );
    }

    // Gradient w.r.t. the block-diagonal memory coefficients.
    // NOTE(review): the extra 1/in.Columns() factor also averages over the
    // input width — confirm this scaling is intended.
    m_diagonal_diff.Reshape( m_block_diagonal.Rows(), m_block_diagonal.Columns() );
    m_diagonal_diff.Sgemm( 0.0f, avg_lr / in.Columns(), m_memory_diff, CUBLAS_OP_N, in, CUBLAS_OP_T );

    // Gradient normalization of both update buffers, variant 1: hard clipping
    // with a range scaled by the effective learning rate.
    if ( m_norm_type == "Clip" ) {
        float range = m_norm_param * (-avg_lr);
        m_w_momentum.Clip( -range, range );
        m_diagonal_diff.Clip( -range, range );
    }
    // Variant 2: L2 rescale.
    // NOTE(review): the norms are measured after scaling by the effective
    // learning rate, and the rescale targets norm 1.0 rather than
    // m_norm_param — confirm this matches the intended normalization.
    if ( m_norm_type == "L2Norm" ) {
        float w_norm = m_w_momentum.L2Norm() * (-avg_lr);
        if ( w_norm > m_norm_param ) {
            m_w_momentum.Scale( 1.0f / w_norm );
        }
        float d_norm = m_diagonal_diff.L2Norm() * (-avg_lr);
        if ( d_norm > m_norm_param ) {
            m_diagonal_diff.Scale( 1.0f / d_norm );
        }
    }

    // Apply the accumulated updates to the weight and the memory filter.
    m_weight.Add( 1.0f, 1.0f, m_w_momentum );
    m_filter.UpdateSfsmnFilter( m_length, m_diagonal_diff );
}