void ConvBaseOperator::allocConvWorkSpace() { hl_conv_workspace(imageDesc_, outputDesc_, filterDesc_, convDesc_, &fwdAlgo_, &fwdLimitBytes_, &bwdDataAlgo_, &bwdDataLimitBytes_, &bwdFilterAlgo_, &bwdFilterLimitBytes_); size_t maxWorkSpace = 0; maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_); if (maxWorkSpace > workSpaceInBytes_) { if (workSpaceInBytes_ != 0) { hl_free_mem_device(workSpace_); } // total amount of storage needed workSpace_ = hl_malloc_device(maxWorkSpace); workSpaceInBytes_ = maxWorkSpace; } }
void ConvBaseProjection::reshape(int batchSize) { size_t width = calOutputSize(); CHECK_EQ(width, out_->value->getWidth()); CHECK_EQ(calInputSize(), in_->value->getWidth()); reshapeTensorDesc(batchSize); bool useDilation = false; if (dilationH_ > 1 || dilationW_ > 1) { useDilation = true; } hl_conv_workspace(imageDesc_, outputDesc_, filterDesc_, convDesc_, &fwdAlgo_, &fwdLimitBytes_, &bwdDataAlgo_, &bwdDataLimitBytes_, &bwdFilterAlgo_, &bwdFilterLimitBytes_, useDilation); size_t maxWorkSpace = 0; maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_); workSpaceInBytes_ = maxWorkSpace; VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_ << " / " << bwdDataAlgo_ << " / " << bwdFilterAlgo_; }