void InterpolationLayer::backward(const UpdateCallback& callback) { MatrixPtr outG = getOutputGrad(); MatrixPtr weightV = getInputValue(0); MatrixPtr inV1 = getInputValue(1); MatrixPtr inV2 = getInputValue(2); MatrixPtr inG0 = getInputGrad(0); MatrixPtr inG1 = getInputGrad(1); MatrixPtr inG2 = getInputGrad(2); size_t batchSize = inV1->getHeight(); size_t dataDim = inV1->getWidth(); REGISTER_TIMER_INFO("BwInterpTimer", getName().c_str()); if (inG0) { Matrix::resizeOrCreate(tmpMatrix, batchSize, dataDim, false, useGpu_); // inG0 += outG .* (inV1 - inV2) tmpMatrix->sub(*inV1, *inV2); inG0->rowDotMul(0, *outG, *tmpMatrix); } if (inG1) { // inG1 += outG * weight inG1->addRowScale(0, *outG, *weightV); } if (inG2) { // inG2 += outG * weightLast inG2->addRowScale(0, *outG, *weightLast_); } }
/**
 * Forward pass: per-row linear interpolation between two inputs.
 *
 * Inputs (by index):
 *   0 - weight w, one scalar per row (batchSize x 1)
 *   1 - in1 (batchSize x dataDim)
 *   2 - in2 (batchSize x dataDim)
 *
 * Output (batchSize x dataDim):
 *   out = w * in1 + (1 - w) * in2
 * where each row's weight scalar scales that whole row. The column (1 - w)
 * is cached in weightLast_ so backward() can reuse it.
 *
 * @param passType  forwarded to Layer::forward.
 */
void InterpolationLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr weightV = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);
  MatrixPtr inV2 = getInputValue(2);

  size_t batchSize = inV1->getHeight();
  size_t dataDim = inV1->getWidth();

  CHECK_EQ(dataDim, getSize());
  CHECK_EQ(dataDim, inV2->getWidth());
  // batchSize is read from inV1, so the other inputs are what must be checked
  // against it. The weight has to be a batchSize x 1 column to match
  // weightLast_ below and the per-row scaling in addRowScale.
  CHECK_EQ(batchSize, weightV->getHeight());
  CHECK_EQ(1UL, weightV->getWidth());
  CHECK_EQ(batchSize, inV2->getHeight());

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, dataDim);
  }

  MatrixPtr outV = getOutputValue();

  // weightLast_ = 1 - weight, kept for backward().
  Matrix::resizeOrCreate(weightLast_, batchSize, 1, false, useGpu_);
  weightLast_->one();
  weightLast_->sub(*weightV);

  REGISTER_TIMER_INFO("FwInterpTimer", getName().c_str());
  // outV = inV1 * weight + inV2 * weightLast
  outV->addRowScale(0, *inV1, *weightV);
  outV->addRowScale(0, *inV2, *weightLast_);
}