示例#1
0
void maxpool3d::create_dX()
{
  if (dX.getNDims() != 0) return; // already created

  // size dX: the right size taking into account pad and stride
  mwSize szdX[5] = {0,0,0,1,1};
  szdX[0] = stride[0]*(dY.getSizeAtDim(0)-1) - (pad[0]+pad[1]) + pool[0];
  szdX[1] = stride[1]*(dY.getSizeAtDim(1)-1) - (pad[2]+pad[3]) + pool[1];
  szdX[2] = stride[2]*(dY.getSizeAtDim(2)-1) - (pad[4]+pad[5]) + pool[2];
  szdX[3] = dY.getSizeAtDim(3);
  szdX[4] = dY.getSizeAtDim(4);

  // create dX
  dX.setMxArray( createVol5dZeros(szdX, dY.dt) );
}
示例#2
0
void maxpool3d::create_dX()
{
  // Validate the bprop inputs: ind must be INT32 and dY must be SINGLE.
  // (The message previously claimed ind must be double, contradicting the check.)
  if ( ind.getElemType()!=mxINT32_CLASS || dY.getElemType()!=mxSINGLE_CLASS ) 
    throw mp3d_ex("In bprop(): dY must be SINGLE, ind must be int32.");

  // Validate the pad/pool settings before they are used in the size computation
  // (see check_pad_pool() for the exact rules).
  check_pad_pool();

  // size dX: the right size taking into account pad and stride.
  // Spatial dims invert the pooling output-size relation
  //   szX[d] = stride[d]*(szY[d]-1) - (pad_lo[d] + pad_hi[d]) + pool[d];
  // dims 3 and 4 are copied from dY.
  mwSize szdX[5] = {0,0,0,1,1};
  szdX[0] = stride[0]*(dY.getSizeAtDim(0)-1) - (pad[0]+pad[1]) + pool[0];
  szdX[1] = stride[1]*(dY.getSizeAtDim(1)-1) - (pad[2]+pad[3]) + pool[1];
  szdX[2] = stride[2]*(dY.getSizeAtDim(2)-1) - (pad[4]+pad[5]) + pool[2];
  szdX[3] = dY.getSizeAtDim(3);
  szdX[4] = dY.getSizeAtDim(4);

  // create dX via createVol5dZeros, using dY's dt (presumably dY's device type)
  dX.setMxArray( createVol5dZeros(szdX, dY.dt) );
}
示例#3
0
// Parse the MEX call arguments and build the maxpool3d worker to run.
//
//   no : number of requested outputs; 2 => fprop ([Y, ind]), 1 => bprop (dzdX)
//   vo : output array (not used here)
//   ni : number of inputs
//   vi : inputs; fprop: (X, opts...), bprop: (dzdY, ind[, szX], opts...)
//
// Returns a heap-allocated maxpool3d_gpu or maxpool3d_cpu, chosen by the device
// of the first input (GPU path only when built WITH_GPUARRAY).
// Throws mp3d_ex on any malformed call.
maxpool3d* factory_mp3d_homebrew::parse_and_create(int no, mxArray *vo[], int ni, mxArray const *vi[])
{
  if (ni < 1) 
    throw mp3d_ex("Too few input arguments.");

  // fprop or bprop? Decided by the number of requested outputs.
  maxpool3d holder;
  int opt_beg = -1;
  xpuMxArrayTW::DEV_TYPE dt;

  if (no == 2) { // fprop
    holder.X.setMxArray( const_cast<mxArray*>(vi[0]) ); // we won't change it
    dt = holder.X.getDevice();

    // ni >= 1 is already guaranteed by the check above.
    if ( holder.X.getElemType() != mxSINGLE_CLASS ) 
      throw mp3d_ex("For fprop(), there should be at least one input, X, of SINGLE type,"
                    "be all gpuArray or be all mxArray.\n");

    holder.ct = maxpool3d::FPROP;
    opt_beg = 1;
  } 
  else if (no == 1) { // bprop
    static const char *const bprop_usage =
      "For bprop(): there should be at least 2 arguments, dzdY, ind.\n"
      "The dzdY must be SINGLE, the max index ind must be int32,"
      "they should be both gpuArray or be both mxArray.\n";

    // Check the argument count BEFORE touching vi[1]: with ni == 1 it does
    // not exist and reading it would be out of bounds.
    if (ni < 2)
      throw mp3d_ex(bprop_usage);

    holder.dY.setMxArray( const_cast<mxArray*>(vi[0]) );
    holder.ind.setMxArray( const_cast<mxArray*>(vi[1]) );
    dt = holder.dY.getDevice();

    if ( holder.dY.getElemType()  != mxSINGLE_CLASS || 
         holder.ind.getElemType() != mxINT32_CLASS ) 
      throw mp3d_ex(bprop_usage);

    holder.ct = maxpool3d::BPROP;
    opt_beg = 2;
  } 
  else {
    throw mp3d_ex("Unrecognized arguments/way of calling. "
      "The output should be either [Y, ind] (fprop) or dzdX (bprop). ");
  }

  // if bprop: create dX if szX is provided args:(dzdY, ind, szX)
  if (holder.ct == maxpool3d::BPROP) {
    if (ni >= 3 && !mxIsChar(vi[2]) ) {
      // szX provided, check it. Throw (rather than mexErrMsgTxt) so that error
      // handling is uniform and holder is unwound normally.
      if (!mxIsDouble(vi[2])) 
        throw mp3d_ex("The third argument szX must be a double array.\n");
      double const *ptr = static_cast<double const*>( mxGetData(vi[2]) );
      mwSize nelem = mxGetNumberOfElements(vi[2]);
      if (nelem > 5 || nelem < 3) 
        throw mp3d_ex("The third argument must be: 3 <= numel(szX) <= 5.\n");

      // get szX; trailing dims not given default to 1
      mwSize szX[5];
      szX[3] = szX[4] = 1;
      for (mwSize i = 0; i < nelem; ++i) {
        // Converting a negative or NaN double to the unsigned mwSize is
        // undefined behaviour, so reject such entries up front.
        if ( !(ptr[i] >= 0) )
          throw mp3d_ex("The third argument szX must contain non-negative sizes.\n");
        szX[i] = static_cast<mwSize>(ptr[i]);
      }
      
      // create the dX
      holder.dX.setMxArray( createVol5dZeros(szX, holder.dY.dt) );

      // reset the option beginning
      opt_beg = 3;
    } 
    else // issue the warning
      mexWarnMsgTxt("For bprop(), the calling method with 2 args:\n"
        "[...] = mex_maxpool3d(dzdY, ind, ...)\n"
        "is deprecated, because this could cause ambiguity when inferring input X size.\n"
        "Use the new one to specify the size for input X (or dzdX) explicitly:\n"
        "[...] = mex_maxpool3d(dzdY, ind, szX,...)\n");
  }

  // set options (presumably the trailing name/value pairs starting at vi[opt_beg])
  set_options(holder, opt_beg, ni, vi);

  // check validity
  check_padpool(holder);

  // create the desired worker and set the parameters
#ifdef WITH_GPUARRAY
  if (dt == xpuMxArrayTW::GPU)
    return new maxpool3d_gpu(holder);
  else
    return new maxpool3d_cpu(holder);
#else
  return new maxpool3d_cpu(holder);
#endif // WITH_GPUARRAY
}