/**
 * Compute the parameter gradients for one step of the recurrent unrolling.
 *
 * Behavior, driven by the step counter `gradientStep`:
 *  - While gradientStep < rho - 1:
 *    - recurrentModule receives (input, error).
 *    - inputModule receives (input, delta of mergeModule).
 *    - feedbackModule receives (the stored feedback output at index
 *      feedbackOutputParameter.size() - 2 - gradientStep, delta of
 *      mergeModule).
 *  - On the remaining step(s):
 *    - recurrentModule, inputModule and feedbackModule have their gradients
 *      zeroed via GradientZeroVisitor.
 *    - initialModule receives (input, delta of startModule).
 *
 * After every call the counter advances. Once it reaches rho, the counter is
 * reset to 0 and the stored feedback outputs are cleared.
 *
 * @param input Input of this step (bound as an rvalue reference).
 * @param error Error signal for this step (used only on the non-final path).
 * @param gradient Unused.
 */
void Recurrent<InputDataType, OutputDataType>::Gradient(
    arma::Mat<eT>&& input,
    arma::Mat<eT>&& error,
    arma::Mat<eT>&& /* gradient */)
{
  if (gradientStep >= (rho - 1))
  {
    // Remaining step: zero the per-step modules and push the start delta into
    // the initial module.
    boost::apply_visitor(GradientZeroVisitor(), recurrentModule);
    boost::apply_visitor(GradientZeroVisitor(), inputModule);
    boost::apply_visitor(GradientZeroVisitor(), feedbackModule);

    boost::apply_visitor(
        GradientVisitor(std::move(input),
                        std::move(boost::apply_visitor(deltaVisitor,
                                                       startModule))),
        initialModule);
  }
  else
  {
    // NOTE(review): `input` is std::move'd into two visitors below. std::move
    // is only a cast, so this assumes GradientVisitor binds the rvalue
    // reference without moving from it -- confirm.
    boost::apply_visitor(GradientVisitor(std::move(input), std::move(error)),
                         recurrentModule);

    boost::apply_visitor(
        GradientVisitor(std::move(input),
                        std::move(boost::apply_visitor(deltaVisitor,
                                                       mergeModule))),
        inputModule);

    // NOTE(review): assumes feedbackOutputParameter holds at least
    // gradientStep + 2 entries; otherwise this unsigned subtraction wraps.
    const size_t feedbackIndex =
        feedbackOutputParameter.size() - 2 - gradientStep;

    boost::apply_visitor(
        GradientVisitor(std::move(feedbackOutputParameter[feedbackIndex]),
                        std::move(boost::apply_visitor(deltaVisitor,
                                                       mergeModule))),
        feedbackModule);
  }

  // Advance the step counter; after rho steps start a fresh sequence.
  if (++gradientStep == rho)
  {
    gradientStep = 0;
    feedbackOutputParameter.clear();
  }
}
/**
 * Compute the parameter gradients of every module held by this container.
 *
 * Each module in `network` is visited with a GradientVisitor built from:
 *  - that module's own output parameter, as the "input" argument;
 *  - the full `error` matrix, as the delta argument.
 *
 * NOTE(review): the Concat-level `input` is ignored. Every module receives the
 * whole `error` rather than a per-module portion of it. Confirm against
 * Forward/Backward that this matches how module outputs are combined.
 *
 * @param input Unused.
 * @param error Error signal handed unchanged to every module.
 * @param gradient Unused.
 */
void Concat<InputDataType, OutputDataType>::Gradient(
    arma::Mat<eT>&& /* input */,
    arma::Mat<eT>&& error,
    arma::Mat<eT>&& /* gradient */)
{
  for (auto& module : network)
  {
    boost::apply_visitor(
        GradientVisitor(
            std::move(boost::apply_visitor(outputParameterVisitor, module)),
            std::move(error)),
        module);
  }
}
/**
 * Compute the parameter gradients of every layer in the sequence, walking
 * from the last layer back to the first.
 *
 * Layer k is visited with a GradientVisitor built from:
 *  - "input": the output parameter of layer k - 1, or `input` for the first
 *    layer;
 *  - "delta": the delta of layer k + 1, or `error` for the last layer.
 *
 * Edge cases that previously indexed out of bounds:
 *  - An empty network has nothing to update, so the call returns immediately.
 *  - A single-layer network is both first and last, so it receives
 *    (`input`, `error`) directly.
 *
 * @param input Input that was fed to the first layer.
 * @param error Error signal for the last layer's output.
 * @param gradient Unused.
 */
void Sequential<InputDataType, OutputDataType, CustomLayers...>::Gradient(
    arma::Mat<eT>&& input,
    arma::Mat<eT>&& error,
    arma::Mat<eT>&& /* gradient */)
{
  // Guard: network[size - 2] and network[1] below require at least two layers.
  if (network.size() == 0)
    return;

  if (network.size() == 1)
  {
    boost::apply_visitor(GradientVisitor(std::move(input), std::move(error)),
                         network.front());
    return;
  }

  // Last layer: its input is the output of the layer before it, and its
  // delta is the incoming error.
  boost::apply_visitor(
      GradientVisitor(
          std::move(boost::apply_visitor(outputParameterVisitor,
                                         network[network.size() - 2])),
          std::move(error)),
      network.back());

  // Interior layers, from size - 2 down to 1.
  for (size_t layer = network.size() - 2; layer > 0; --layer)
  {
    boost::apply_visitor(
        GradientVisitor(
            std::move(boost::apply_visitor(outputParameterVisitor,
                                           network[layer - 1])),
            std::move(boost::apply_visitor(deltaVisitor,
                                           network[layer + 1]))),
        network[layer]);
  }

  // First layer: its input is the container's input, and its delta comes
  // from the second layer.
  boost::apply_visitor(
      GradientVisitor(std::move(input),
                      std::move(boost::apply_visitor(deltaVisitor,
                                                     network[1]))),
      network.front());
}