template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void Sequential<InputDataType, OutputDataType, CustomLayers...>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  // Pass the upstream gradient gy through the last layer in the sequence.
  boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
      outputParameterVisitor, network.back())), std::move(gy),
      std::move(boost::apply_visitor(deltaVisitor, network.back()))),
      network.back());

  // Walk the remaining layers from back to front; each layer consumes the
  // delta produced by the layer that follows it.
  for (size_t i = 2; i < network.size() + 1; ++i)
  {
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, network[network.size() - i])), std::move(
        boost::apply_visitor(deltaVisitor, network[network.size() - i + 1])),
        std::move(boost::apply_visitor(deltaVisitor,
        network[network.size() - i]))), network[network.size() - i]);
  }

  // The first layer's delta is the gradient with respect to this module's
  // input.
  g = boost::apply_visitor(deltaVisitor, network.front());
}
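The loop indexing above is easy to misread: for i = 2 .. network.size(), the layer at network.size() - i reads the delta written by the layer at network.size() - i + 1. Below is a minimal standalone sketch of that chaining pattern, assuming only Armadillo; the Layer struct and its gradient-doubling backward rule are hypothetical stand-ins, not mlpack's API.

#include <armadillo>
#include <vector>

struct Layer
{
  arma::mat outputParameter; // Cached output from the forward pass.
  arma::mat delta;           // Delta computed during the backward pass.

  // Hypothetical backward rule: scale the upstream gradient by 2.
  void Backward(const arma::mat& /* output */, const arma::mat& gy)
  {
    delta = 2.0 * gy;
  }
};

int main()
{
  std::vector<Layer> network(3);
  arma::mat gy(4, 1, arma::fill::ones);

  // The last layer consumes the error coming from outside the sequence.
  network.back().Backward(network.back().outputParameter, gy);

  // Same index arithmetic as Sequential::Backward(): the layer at
  // network.size() - i reads the delta of the layer after it.
  for (size_t i = 2; i < network.size() + 1; ++i)
  {
    network[network.size() - i].Backward(
        network[network.size() - i].outputParameter,
        network[network.size() - i + 1].delta);
  }

  // The first layer's delta is the gradient handed to the caller.
  arma::mat g = network.front().delta;
  g.print("g (expected 8s: gy doubled once per layer):");
}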
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void Recurrent<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  // Accumulate the upstream gradient into the error carried across
  // backward-through-time steps.
  if (!recurrentError.is_empty())
  {
    recurrentError += gy;
  }
  else
  {
    recurrentError = gy;
  }

  if (backwardStep < (rho - 1))
  {
    // Ordinary time step: backpropagate the carried error through the
    // recurrent module, ...
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, recurrentModule)), std::move(recurrentError),
        std::move(boost::apply_visitor(deltaVisitor, recurrentModule))),
        recurrentModule);

    // ... then through the input module to produce the input gradient g, ...
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, inputModule)), std::move(
        boost::apply_visitor(deltaVisitor, recurrentModule)), std::move(g)),
        inputModule);

    // ... and through the feedback module, whose delta feeds the next
    // (earlier) time step.
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, feedbackModule)), std::move(
        boost::apply_visitor(deltaVisitor, recurrentModule)), std::move(
        boost::apply_visitor(deltaVisitor, feedbackModule))), feedbackModule);
  }
  else
  {
    // Earliest time step: backpropagate through the module that produced
    // the initial recurrent state.
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, initialModule)), std::move(recurrentError),
        std::move(g)), initialModule);
  }

  // Carry the feedback delta into the next backward step.
  recurrentError = boost::apply_visitor(deltaVisitor, feedbackModule);
  backwardStep++;
}
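The essential control flow here is the error carry: each call folds the new upstream gradient gy into recurrentError, and the feedback module's delta seeds the next (earlier) step. Here is a standalone sketch of just that pattern, assuming only Armadillo; the 0.5 "feedback" factor is a made-up stand-in for the feedback module's Backward(), not mlpack's API.

#include <armadillo>

int main()
{
  const size_t rho = 3;                 // Number of unrolled time steps.
  arma::mat recurrentError;             // Carried across backward steps.
  arma::mat gy(2, 1, arma::fill::ones); // Upstream gradient per step.

  for (size_t backwardStep = 0; backwardStep < rho; ++backwardStep)
  {
    // Accumulate the new upstream gradient into the carried error.
    if (!recurrentError.is_empty())
      recurrentError += gy;
    else
      recurrentError = gy;

    // Stand-in for the feedback module's Backward(): its delta becomes
    // the recurrent error seen one step earlier in time.
    arma::mat feedbackDelta = 0.5 * recurrentError;
    recurrentError = feedbackDelta;
  }

  recurrentError.print("carried recurrent error after rho steps:");
}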
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void Concat<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  size_t outSize = 0;
  size_t elements = 0;

  for (size_t i = 0, j = 0; i < network.size(); ++i, j += elements)
  {
    elements = boost::apply_visitor(outputParameterVisitor,
        network[i]).n_elem;

    // Slice out the part of gy that belongs to child i: a row range when
    // the child outputs were concatenated into a single column, otherwise
    // the i-th column.
    arma::mat delta;
    if (gy.n_cols == 1)
    {
      delta = gy.submat(j, 0, j + elements - 1, 0);
    }
    else
    {
      delta = gy.submat(0, i, elements - 1, i);
    }

    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, network[i])), std::move(delta), std::move(
        boost::apply_visitor(deltaVisitor, network[i]))), network[i]);

    // Track the largest child delta for the zero-padded layout below.
    if (boost::apply_visitor(deltaVisitor, network[i]).n_elem > outSize)
    {
      outSize = boost::apply_visitor(deltaVisitor, network[i]).n_elem;
    }

    // If all children share the same input, their deltas simply sum.
    if (same)
    {
      if (i == 0)
      {
        g = std::move(boost::apply_visitor(deltaVisitor, network[i]));
      }
      else
      {
        g += std::move(boost::apply_visitor(deltaVisitor, network[i]));
      }
    }
  }

  // Otherwise stack the child deltas as zero-padded columns of g.
  if (!same)
  {
    g = arma::zeros(outSize, network.size());
    for (size_t i = 0; i < network.size(); ++i)
    {
      size_t elements = boost::apply_visitor(deltaVisitor, network[i]).n_elem;
      if (elements < outSize)
      {
        g.submat(0, i, elements - 1, i) = arma::vectorise(
            boost::apply_visitor(deltaVisitor, network[i]));
      }
      else
      {
        g.col(i) = arma::vectorise(
            boost::apply_visitor(deltaVisitor, network[i]));
      }
    }
  }
}
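The two layout decisions above (slicing gy by rows in the single-column case, and stacking differently sized child deltas into zero-padded columns when `same` is false) can be seen in isolation in the following standalone sketch. It assumes only Armadillo; the child element counts are hypothetical, and each gy slice stands in for the delta a real child layer's Backward() would produce.

#include <armadillo>
#include <vector>

int main()
{
  // Two hypothetical children producing 3 and 2 output elements; gy holds
  // their concatenated upstream gradient in a single column.
  std::vector<size_t> childElements = {3, 2};
  arma::mat gy(5, 1, arma::fill::randu);

  std::vector<arma::mat> childDelta(childElements.size());
  size_t outSize = 0;
  size_t elements = 0;

  for (size_t i = 0, j = 0; i < childElements.size(); ++i, j += elements)
  {
    elements = childElements[i];

    // Row slice of gy belonging to child i (the gy.n_cols == 1 branch).
    childDelta[i] = gy.submat(j, 0, j + elements - 1, 0);

    if (childDelta[i].n_elem > outSize)
      outSize = childDelta[i].n_elem;
  }

  // The non-"same" recombination: one zero-padded column of g per child.
  arma::mat g = arma::zeros(outSize, childElements.size());
  for (size_t i = 0; i < childElements.size(); ++i)
  {
    g.submat(0, i, childDelta[i].n_elem - 1, i) =
        arma::vectorise(childDelta[i]);
  }

  g.print("g (one zero-padded column per child):");
}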
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  if (intermediateGradient.is_empty() && backwardStep == 0)
  {
    // Initialize the attention gradients.
    size_t weights = boost::apply_visitor(weightSizeVisitor, rnnModule) +
        boost::apply_visitor(weightSizeVisitor, actionModule);

    intermediateGradient = arma::zeros(weights, 1);
    attentionGradient = arma::zeros(weights, 1);

    // Initialize the action error.
    actionError = arma::zeros(
        boost::apply_visitor(outputParameterVisitor, actionModule).n_rows,
        boost::apply_visitor(outputParameterVisitor, actionModule).n_cols);
  }

  // Propagate the attention gradients.
  if (backwardStep == 0)
  {
    size_t offset = 0;
    offset += boost::apply_visitor(GradientSetVisitor(
        std::move(intermediateGradient), offset), rnnModule);
    boost::apply_visitor(GradientSetVisitor(
        std::move(intermediateGradient), offset), actionModule);

    attentionGradient.zeros();
  }

  // Back-propagate through time.
  for (; backwardStep < rho; backwardStep++)
  {
    // The first step consumes the external error; later steps reuse the
    // action module's delta.
    if (backwardStep == 0)
    {
      recurrentError = gy;
    }
    else
    {
      recurrentError = actionDelta;
    }

    // Restore the outputs saved during the forward pass for this step.
    for (size_t l = 0; l < network.size(); ++l)
    {
      boost::apply_visitor(LoadOutputParameterVisitor(
          std::move(moduleOutputParameter)), network[network.size() - 1 - l]);
    }

    if (backwardStep == (rho - 1))
    {
      boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
          outputParameterVisitor, actionModule)), std::move(actionError),
          std::move(actionDelta)), actionModule);
    }
    else
    {
      boost::apply_visitor(BackwardVisitor(std::move(initialInput),
          std::move(actionError), std::move(actionDelta)), actionModule);
    }

    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, rnnModule)), std::move(recurrentError),
        std::move(rnnDelta)), rnnModule);

    // Accumulate the input gradient from the recurrent module's delta.
    if (backwardStep == 0)
    {
      g = rnnDelta.col(1);
    }
    else
    {
      g += rnnDelta.col(1);
    }

    IntermediateGradient();
  }
}
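The BPTT loop above has a simple skeleton once the visitor plumbing is stripped away: step 0 consumes the external error gy, later steps consume the previous actionDelta, and the input gradient g accumulates one column of the recurrent module's delta per step. Here is a standalone sketch of that skeleton, assuming only Armadillo; the 0.5 scaling and the repmat stand in for the actionModule and rnnModule Backward() calls and are hypothetical, not mlpack's API.

#include <armadillo>

int main()
{
  const size_t rho = 4;                 // Unrolled attention steps.
  arma::mat gy(3, 1, arma::fill::ones); // Error from the layer above.
  arma::mat g, recurrentError, actionDelta;

  for (size_t backwardStep = 0; backwardStep < rho; ++backwardStep)
  {
    // Step 0 uses the external error; later steps reuse the action delta.
    recurrentError = (backwardStep == 0) ? gy : actionDelta;

    // Stand-ins for actionModule's and rnnModule's Backward() calls.
    actionDelta = 0.5 * recurrentError;
    arma::mat rnnDelta = arma::repmat(recurrentError, 1, 2);

    // Accumulate the input gradient from the second column, as above.
    if (backwardStep == 0)
      g = rnnDelta.col(1);
    else
      g += rnnDelta.col(1);
  }

  g.print("accumulated input gradient g:");
}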