コード例 #1
0
/**
 * Create the filter, image, output and convolution descriptors for this
 * projection, then reset the algorithm-selection state.
 *
 * The filter descriptor is built from per-group sizes: channels_ and
 * numFilters_ are each divided by groups_ before being passed on.
 * All algorithm ids, per-algorithm byte limits and the workspace size are
 * set to 0 here.
 */
void ConvBaseProjection::initCudnn() {
  // Per-group sizes used for the filter descriptor.
  const auto channelsPerGroup = channels_ / groups_;
  const auto filtersPerGroup = numFilters_ / groups_;
  hl_create_filter_descriptor(
      &filterDesc_, channelsPerGroup, filtersPerGroup, filterH_, filterW_);

  // Tensor descriptors are created empty; nothing in this function fills
  // them in with a shape.
  hl_create_tensor_descriptor(&imageDesc_);
  hl_create_tensor_descriptor(&outputDesc_);

  // The convolution descriptor is created last because it takes the image
  // and filter descriptors as inputs.
  hl_create_convolution_descriptor(&convDesc_, imageDesc_, filterDesc_,
                                   paddingH_, paddingW_,
                                   strideH_, strideW_,
                                   dilationH_, dilationW_);

  // Workspace bookkeeping starts out empty.
  workSpaceInBytes_ = 0;
  bwdFilterLimitBytes_ = 0;
  bwdDataLimitBytes_ = 0;
  fwdLimitBytes_ = 0;

  // Algorithm id 0 is used as the default for every pass.
  bwdDataAlgo_ = 0;
  bwdFilterAlgo_ = 0;
  fwdAlgo_ = 0;
}
コード例 #2
0
/**
 * Create the filter, image, output and convolution descriptors for this
 * operator.
 *
 * The filter descriptor uses channels_ and numFilters_ as-is together with
 * filterSizeY_ / filterSize_. The convolution descriptor is built from the
 * image and filter descriptors plus the padding and stride members.
 * NOTE(review): the Y-suffixed members are presumably the vertical extent
 * and the unsuffixed ones the horizontal extent -- confirm against the
 * class declaration.
 */
void ConvBaseOperator::computeConvSizes() {
  hl_create_filter_descriptor(&filterDesc_,
                              channels_,
                              numFilters_,
                              filterSizeY_,
                              filterSize_);

  // Tensor descriptors are created empty; nothing in this function fills
  // them in with a shape.
  hl_create_tensor_descriptor(&imageDesc_);
  hl_create_tensor_descriptor(&outputDesc_);

  // Created last: it takes the image and filter descriptors as inputs.
  hl_create_convolution_descriptor(
      &convDesc_, imageDesc_, filterDesc_,
      paddingY_, padding_, strideY_, stride_);
}