Example 1 (0 votes):
void ConvBaseProjection::reshape(int batchSize) {
  size_t width = calOutputSize();
  CHECK_EQ(width, out_->value->getWidth());
  CHECK_EQ(calInputSize(), in_->value->getWidth());

  reshapeTensorDesc(batchSize);
  bool useDilation = false;
  if (dilationH_ > 1 || dilationW_ > 1) {
    useDilation = true;
  }
  hl_conv_workspace(imageDesc_,
                    outputDesc_,
                    filterDesc_,
                    convDesc_,
                    &fwdAlgo_,
                    &fwdLimitBytes_,
                    &bwdDataAlgo_,
                    &bwdDataLimitBytes_,
                    &bwdFilterAlgo_,
                    &bwdFilterLimitBytes_,
                    useDilation);

  size_t maxWorkSpace = 0;
  maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
  maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
  workSpaceInBytes_ = maxWorkSpace;

  VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
          << " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
}
Example 2 (0 votes):
void CudnnConvLayer::forward(PassType passType) {
  Layer::forward(passType);

  // Allocate/resize the output matrix for this batch.
  int batchSize = getInput(0).getBatchSize();
  resetOutput(batchSize, calOutputSize());

  // Run one cuDNN convolution projection per input layer; each projection
  // accumulates its result into the shared output.
  for (size_t i = 0; i != inputLayers_.size(); ++i) {
    projections_[i]->forward(&getInput(i), &getOutput(), passType);
  }

  if (biases_) {
    REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str());
    // FIX: renamed from `batchSize` — the original declaration shadowed the
    // outer `batchSize` above, which is error-prone and trips -Wshadow.
    // The value itself is still re-derived from the first input's output
    // height, exactly as before.
    int biasBatchSize = inputLayers_[0]->getOutputValue()->getHeight();
    // Describe the output as an NCHW tensor (per group) so the bias,
    // which is one value per filter channel, can be broadcast-added.
    hl_tensor_reshape(outputDesc_, biasBatchSize, numFilters_ / groups_[0],
        outputH_[0], outputW_[0], numFilters_ * outputH_[0] * outputW_[0],
        outputH_[0] * outputW_[0], outputW_[0], 1);
    outputOffset_ = getOutputValue()->getWidth() / groups_[0];
    // Add the per-group bias slice to the corresponding output slice.
    for (int g = 0; g < groups_[0]; ++g) {
      real *biasData = biases_->getW()->getData() + biasOffset_ * g;
      real *outData = getOutputValue()->getData() + outputOffset_ * g;
      hl_convolution_forward_add_bias(biasDesc_, biasData,
                                      outputDesc_, outData);
    }
  }

  forwardActivation();
}