示例#1
0
void ConvBaseOperator::allocConvWorkSpace() {
  hl_conv_workspace(imageDesc_,
                    outputDesc_,
                    filterDesc_,
                    convDesc_,
                    &fwdAlgo_,
                    &fwdLimitBytes_,
                    &bwdDataAlgo_,
                    &bwdDataLimitBytes_,
                    &bwdFilterAlgo_,
                    &bwdFilterLimitBytes_);

  size_t maxWorkSpace = 0;
  maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
  maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);

  if (maxWorkSpace > workSpaceInBytes_) {
    if (workSpaceInBytes_ != 0) {
      hl_free_mem_device(workSpace_);
    }
    // total amount of storage needed
    workSpace_ = hl_malloc_device(maxWorkSpace);
    workSpaceInBytes_ = maxWorkSpace;
  }
}
示例#2
0
/**
 * @brief Re-validate sizes and refresh descriptors, algorithm choices and
 *        the workspace size requirement.
 *
 * Steps, in order:
 *  1. Check that calOutputSize() equals the width of out_->value and that
 *     calInputSize() equals the width of in_->value.
 *  2. Call reshapeTensorDesc(batchSize).
 *  3. Call hl_conv_workspace() with the descriptors, the addresses of the
 *     algorithm / byte-limit members, and a flag that is true when either
 *     dilation factor is greater than 1.
 *  4. Store the largest of the three byte limits in workSpaceInBytes_. This
 *     function only records the size; it allocates nothing itself.
 *  5. Log the three algorithm members at verbosity level 3.
 *
 * @param batchSize forwarded unchanged to reshapeTensorDesc().
 */
void ConvBaseProjection::reshape(int batchSize) {
  // Note: the CHECK_EQ arguments are kept verbatim because the macro embeds
  // their source text in its failure message.
  const size_t width = calOutputSize();
  CHECK_EQ(width, out_->value->getWidth());
  CHECK_EQ(calInputSize(), in_->value->getWidth());

  reshapeTensorDesc(batchSize);

  // Dilation is in effect as soon as either axis uses a factor above 1.
  const bool useDilation = (dilationH_ > 1) || (dilationW_ > 1);

  hl_conv_workspace(imageDesc_, outputDesc_, filterDesc_, convDesc_,
                    &fwdAlgo_, &fwdLimitBytes_,
                    &bwdDataAlgo_, &bwdDataLimitBytes_,
                    &bwdFilterAlgo_, &bwdFilterLimitBytes_,
                    useDilation);

  // Largest of the forward / backward-data / backward-filter limits.
  const size_t fwdOrBwdData = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
  const size_t requiredBytes = std::max(fwdOrBwdData, bwdFilterLimitBytes_);
  workSpaceInBytes_ = requiredBytes;

  VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
          << " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_;
}