示例#1
0
void InterpolationLayer::backward(const UpdateCallback& callback) {
  MatrixPtr outG = getOutputGrad();
  MatrixPtr weightV = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);
  MatrixPtr inV2 = getInputValue(2);
  MatrixPtr inG0 = getInputGrad(0);
  MatrixPtr inG1 = getInputGrad(1);
  MatrixPtr inG2 = getInputGrad(2);

  size_t batchSize = inV1->getHeight();
  size_t dataDim = inV1->getWidth();

  REGISTER_TIMER_INFO("BwInterpTimer", getName().c_str());

  if (inG0) {
    Matrix::resizeOrCreate(tmpMatrix, batchSize, dataDim, false, useGpu_);

    // inG0 += outG .* (inV1 - inV2)
    tmpMatrix->sub(*inV1, *inV2);
    inG0->rowDotMul(0, *outG, *tmpMatrix);
  }

  if (inG1) {
    // inG1 += outG * weight
    inG1->addRowScale(0, *outG, *weightV);
  }

  if (inG2) {
    // inG2 += outG * weightLast
    inG2->addRowScale(0, *outG, *weightLast_);
  }
}
示例#2
0
// Forward pass of the interpolation layer.
//
// Reads three inputs: input 0 is a per-sample weight, inputs 1 and 2 are the
// two data matrices to blend. Computes
//   out = input1 * weight + input2 * (1 - weight)
// and keeps (1 - weight) in weightLast_ so that backward() can reuse it.
//
// @param passType forwarded unchanged to Layer::forward.
void InterpolationLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr weightV = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);
  MatrixPtr inV2 = getInputValue(2);

  size_t batchSize = inV1->getHeight();
  size_t dataDim = inV1->getWidth();

  // Both data inputs must match the layer size and each other.
  CHECK_EQ(dataDim, getSize());
  CHECK_EQ(dataDim, inV2->getWidth());
  // The weight input needs one row per sample. The previous check compared
  // batchSize with inV1->getHeight(), i.e. with itself, so it could never
  // fail and the weight's height was not validated here.
  CHECK_EQ(batchSize, weightV->getHeight());
  CHECK_EQ(batchSize, inV2->getHeight());

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, dataDim);
  }

  MatrixPtr outV = getOutputValue();

  // weightLast_ = 1 - weight, stored as a batchSize x 1 column.
  Matrix::resizeOrCreate(weightLast_, batchSize, 1, false, useGpu_);
  weightLast_->one();
  weightLast_->sub(*weightV);

  REGISTER_TIMER_INFO("FwInterpTimer", getName().c_str());
  // outV = inV1 * weight + inV2 * weightLast
  // NOTE(review): both calls accumulate into outV, so this presumably relies
  // on resetOutput() leaving outV zeroed — confirm.
  outV->addRowScale(0, *inV1, *weightV);
  outV->addRowScale(0, *inV2, *weightLast_);
}