Example #1
Backward pass of the Sequential layer container from mlpack's ANN module: the error is propagated from the last layer back to the first, and the first layer's delta becomes the output gradient g.
template<typename InputDataType, typename OutputDataType,
         typename... CustomLayers>
template<typename eT>
void Sequential<InputDataType, OutputDataType, CustomLayers...>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  // Run the backward pass through the last layer, using its stored output
  // and the incoming error gy.
  boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
      outputParameterVisitor, network.back())), std::move(gy),
      std::move(boost::apply_visitor(deltaVisitor, network.back()))),
      network.back());

  // Walk backward through the remaining layers; each layer receives the
  // delta computed by the layer after it.
  for (size_t i = 2; i < network.size() + 1; ++i)
  {
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, network[network.size() - i])), std::move(
        boost::apply_visitor(deltaVisitor, network[network.size() - i + 1])),
        std::move(boost::apply_visitor(deltaVisitor,
        network[network.size() - i]))), network[network.size() - i]);
  }

  // The first layer's delta is the gradient with respect to the input.
  g = boost::apply_visitor(deltaVisitor, network.front());
}
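All four examples dispatch through boost::apply_visitor, so the visitor machinery is worth seeing in isolation. Below is a minimal, self-contained sketch of that pattern, not mlpack's actual BackwardVisitor (the real one is generated over mlpack's full LayerTypes variant and, as the std::move calls above suggest, takes its arguments by rvalue reference to avoid copies); IdentityLayer and ScaleLayer are hypothetical stand-ins.

#include <armadillo>
#include <boost/variant.hpp>

// Hypothetical stand-in layers, each exposing the Backward() interface
// the visitor expects.
struct IdentityLayer
{
  void Backward(const arma::mat& /* input */, arma::mat& gy, arma::mat& g)
  { g = gy; }
};

struct ScaleLayer
{
  void Backward(const arma::mat& /* input */, arma::mat& gy, arma::mat& g)
  { g = 2.0 * gy; }
};

using LayerTypes = boost::variant<IdentityLayer*, ScaleLayer*>;

// Forwards Backward() to whichever concrete layer the variant holds.
class BackwardVisitor : public boost::static_visitor<void>
{
 public:
  BackwardVisitor(const arma::mat& input, arma::mat& gy, arma::mat& g) :
      input(input), gy(gy), g(g) { }

  template<typename LayerType>
  void operator()(LayerType* layer) const { layer->Backward(input, gy, g); }

 private:
  const arma::mat& input;
  arma::mat& gy;
  arma::mat& g;
};

int main()
{
  ScaleLayer layer;
  LayerTypes module = &layer;

  arma::mat input(3, 1, arma::fill::ones);
  arma::mat gy(3, 1, arma::fill::ones);
  arma::mat g;

  boost::apply_visitor(BackwardVisitor(input, gy, g), module);
  g.print("g:");  // a 3x1 matrix of 2s
}

Storing references in the visitor keeps the sketch simple; the important point is only that apply_visitor resolves the concrete layer type at runtime, so one Backward() call site works for every layer in the variant.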
Example #2
Backward pass of mlpack's Recurrent layer: the incoming gradient is merged with the error fed back from the following time step, then propagated through the recurrent, input, and feedback modules (or through the initial module on the final step).
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void Recurrent<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  // Merge the incoming gradient with the error fed back from the
  // following time step, if any.
  if (!recurrentError.is_empty())
  {
    recurrentError += gy;
  }
  else
  {
    recurrentError = gy;
  }

  // Before the final backward step, backpropagate through the recurrent
  // module, then hand its delta to the input and feedback modules.
  if (backwardStep < (rho - 1))
  {
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, recurrentModule)), std::move(recurrentError),
        std::move(boost::apply_visitor(deltaVisitor, recurrentModule))),
        recurrentModule);

    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, inputModule)), std::move(
        boost::apply_visitor(deltaVisitor, recurrentModule)), std::move(g)),
        inputModule);

    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, feedbackModule)), std::move(
        boost::apply_visitor(deltaVisitor, recurrentModule)), std::move(
        boost::apply_visitor(deltaVisitor, feedbackModule))), feedbackModule);
  }
  else
  {
    // On the final backward step, backpropagate through the initial
    // module instead.
    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, initialModule)), std::move(recurrentError),
        std::move(g)), initialModule);
  }

  // Store the feedback module's delta for the next backward call.
  recurrentError = boost::apply_visitor(deltaVisitor, feedbackModule);
  backwardStep++;
}
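The key bookkeeping here is the running recurrentError: at each backward step the gradient arriving from above is added to the error fed back from the step after it. The toy loop below (hypothetical standalone code, not mlpack API; the matrix W stands in for the feedback module) illustrates that accumulation pattern.

#include <armadillo>

int main()
{
  const arma::uword rho = 3;              // number of unrolled time steps
  arma::mat W(2, 2, arma::fill::eye);     // stand-in for the feedback module
  arma::mat recurrentError;               // empty before the first step

  for (arma::uword backwardStep = 0; backwardStep < rho; ++backwardStep)
  {
    arma::mat gy(2, 1, arma::fill::ones);  // gradient from the layer above

    // Same merge as in Recurrent::Backward() above.
    if (!recurrentError.is_empty())
      recurrentError += gy;
    else
      recurrentError = gy;

    // Stand-in for the feedback module's delta: push the error through W.
    recurrentError = W.t() * recurrentError;
  }

  recurrentError.print("accumulated error:");  // one gy added per step
}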
Example #3
Backward pass of mlpack's Concat layer: the incoming gradient is split among the child layers; their deltas are summed when they share a shape, or packed column-wise into g otherwise.
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void Concat<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>&& /* input */, arma::Mat<eT>&& gy, arma::Mat<eT>&& g)
{
  size_t outSize = 0;
  size_t elements = 0;

  for (size_t i = 0, j = 0; i < network.size(); ++i, j += elements)
  {
    elements = boost::apply_visitor(outputParameterVisitor,
        network[i]).n_elem;

    arma::mat delta;
    if (gy.n_cols == 1)
    {
      // gy is one flattened column: this layer's share is a contiguous
      // block of rows.
      delta = gy.submat(j, 0, j + elements - 1, 0);
    }
    else
    {
      // gy holds one column per child layer: take column i.
      delta = gy.submat(0, i, elements - 1, i);
    }

    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, network[i])), std::move(delta), std::move(
        boost::apply_visitor(deltaVisitor, network[i]))), network[i]);

    if (boost::apply_visitor(deltaVisitor, network[i]).n_elem > outSize)
    {
      outSize = boost::apply_visitor(deltaVisitor, network[i]).n_elem;
    }

    // All child layers produce equally shaped deltas: sum them into g.
    if (same)
    {
      if (i == 0)
      {
        g = std::move(boost::apply_visitor(deltaVisitor, network[i]));
      }
      else
      {
        g += std::move(boost::apply_visitor(deltaVisitor, network[i]));
      }
    }
  }

  // Otherwise pack each delta into its own (zero-padded) column of g.
  if (!same)
  {
    g = arma::zeros(outSize, network.size());
    for (size_t i = 0; i < network.size(); ++i)
    {
      size_t elements = boost::apply_visitor(deltaVisitor, network[i]).n_elem;
      if (elements < outSize)
      {
        g.submat(0, i, elements - 1, i) = arma::vectorise(
            boost::apply_visitor(deltaVisitor, network[i]));
      }
      else
      {
        g.col(i) = arma::vectorise(
            boost::apply_visitor(deltaVisitor, network[i]));
      }
    }
  }
}
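The slicing in the loop above depends on the layout of gy: one flattened column means each child gets a block of rows, while several columns mean one column per child. A small standalone check of the row-block extraction (hypothetical example; the sizes in elements are made up):

#include <armadillo>
#include <string>

int main()
{
  // A 6x1 flattened gradient, as in the gy.n_cols == 1 branch above.
  arma::mat gy = arma::linspace<arma::mat>(1, 6, 6);
  arma::uvec elements = {2, 4};  // made-up output sizes of two child layers

  arma::uword j = 0;
  for (arma::uword i = 0; i < elements.n_elem; ++i)
  {
    // Same extraction as gy.submat(j, 0, j + elements - 1, 0) above.
    arma::mat delta = gy.submat(j, 0, j + elements[i] - 1, 0);
    delta.print("delta for child " + std::to_string(i) + ":");
    j += elements[i];
  }
}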
Example #4
Backward pass of mlpack's RecurrentAttention layer: the error is back-propagated through time over rho steps while the attention gradients are accumulated.
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>&& /* input */,
    arma::Mat<eT>&& gy,
    arma::Mat<eT>&& g)
{
  if (intermediateGradient.is_empty() && backwardStep == 0)
  {
    // Initialize the attention gradients.
    size_t weights = boost::apply_visitor(weightSizeVisitor, rnnModule) +
        boost::apply_visitor(weightSizeVisitor, actionModule);

    intermediateGradient = arma::zeros(weights, 1);
    attentionGradient = arma::zeros(weights, 1);

    // Initialize the action error.
    actionError = arma::zeros(
        boost::apply_visitor(outputParameterVisitor, actionModule).n_rows,
        boost::apply_visitor(outputParameterVisitor, actionModule).n_cols);
  }

  // Propagate the attention gradients.
  if (backwardStep == 0)
  {
    size_t offset = 0;
    offset += boost::apply_visitor(GradientSetVisitor(
        std::move(intermediateGradient), offset), rnnModule);
    boost::apply_visitor(GradientSetVisitor(
        std::move(intermediateGradient), offset), actionModule);

    attentionGradient.zeros();
  }

  // Back-propagate through time.
  for (; backwardStep < rho; backwardStep++)
  {
    if (backwardStep == 0)
    {
      recurrentError = gy;
    }
    else
    {
      recurrentError = actionDelta;
    }

    for (size_t l = 0; l < network.size(); ++l)
    {
      boost::apply_visitor(LoadOutputParameterVisitor(
         std::move(moduleOutputParameter)), network[network.size() - 1 - l]);
    }

    if (backwardStep == (rho - 1))
    {
      boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
          outputParameterVisitor, actionModule)), std::move(actionError),
          std::move(actionDelta)), actionModule);
    }
    else
    {
      boost::apply_visitor(BackwardVisitor(std::move(initialInput),
          std::move(actionError), std::move(actionDelta)), actionModule);
    }

    boost::apply_visitor(BackwardVisitor(std::move(boost::apply_visitor(
        outputParameterVisitor, rnnModule)), std::move(recurrentError),
        std::move(rnnDelta)), rnnModule);

    if (backwardStep == 0)
    {
      g = rnnDelta.col(1);
    }
    else
    {
      g += rnnDelta.col(1);
    }

    IntermediateGradient();
  }
}
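GradientSetVisitor lets both modules write their gradients into disjoint slices of one flat, parameter-sized vector; that is what the offset arithmetic at the top of the function arranges. A toy sketch of that packing (hypothetical standalone code, not mlpack's visitor; the weight counts are made up):

#include <armadillo>

int main()
{
  // Made-up parameter counts for the two modules.
  const arma::uword rnnWeights = 4;
  const arma::uword actionWeights = 2;

  arma::mat intermediateGradient = arma::zeros(rnnWeights + actionWeights, 1);

  // Each module fills its own slice; the offset advances past the first,
  // mirroring offset += boost::apply_visitor(GradientSetVisitor(...)) above.
  arma::uword offset = 0;
  intermediateGradient.rows(offset, offset + rnnWeights - 1).fill(1.0);
  offset += rnnWeights;
  intermediateGradient.rows(offset, offset + actionWeights - 1).fill(2.0);

  intermediateGradient.print("packed gradient:");  // four 1s, then two 2s
}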