Beispiel #1
0
void factory_mp3d_homebrew::set_pad(maxpool3d &h, mxArray const *pa )
{
  // Fill h.pad (6 entries) from the user-supplied option value `pa`.
  // Per the error text, the option may hold either 1 or 6 numbers; how a
  // single value is expanded is up to setCArray (presumably broadcast -- TODO confirm).
  const bool accepted = setCArray<mwSize, 6>(pa, h.pad);
  if (accepted)
    return;

  throw mp3d_ex("The length of option pad must be 1 or 6.");
}
Beispiel #2
0
// Parse the MEX call arguments and create the matching maxpool3d worker.
//
// The number of requested outputs selects the operation:
//   no == 2 -> fprop: [Y, ind] = mex_maxpool3d(X, options...)
//   no == 1 -> bprop: dzdX     = mex_maxpool3d(dzdY, ind, szX, options...)
//              (the older form without szX is still accepted but warned about)
//
// no, vo : number / array of MEX outputs (vo is not used here).
// ni, vi : number / array of MEX inputs.
// Returns a newly allocated worker: maxpool3d_gpu when compiled WITH_GPUARRAY
// and the primary input reports the GPU device, maxpool3d_cpu otherwise.
// The caller owns the returned pointer.
// Throws mp3d_ex on any malformed argument list.
maxpool3d* factory_mp3d_homebrew::parse_and_create(int no, mxArray *vo[], int ni, mxArray const *vi[])
{
  if (ni < 1)
    throw mp3d_ex("Too few input arguments.");

  // fprop or bprop?
  maxpool3d holder;
  int opt_beg = -1;
  xpuMxArrayTW::DEV_TYPE dt;

  if (no == 2) { // fprop
    holder.X.setMxArray( (mxArray*) vi[0] ); // we won't change it
    dt = holder.X.getDevice();

    // ni >= 1 is already guaranteed above; only the element type needs checking.
    if (holder.X.getElemType() != mxSINGLE_CLASS)
      throw mp3d_ex("For fprop(), there should be at least one input, X, of SINGLE type, "
                    "be all gpuArray or be all mxArray.\n");

    holder.ct = maxpool3d::FPROP;
    opt_beg = 1;
  }
  else if (no == 1) { // bprop
    static const char *const kBpropUsage =
      "For bprop(): there should be at least 2 arguments, dzdY, ind.\n"
      "The dzdY must be SINGLE, the max index ind must be int32, "
      "they should be both gpuArray or be both mxArray.\n";

    // Validate the argument count BEFORE reading vi[1]; with ni == 1,
    // vi[1] would be read past the end of the input array.
    if (ni < 2)
      throw mp3d_ex(kBpropUsage);

    holder.dY.setMxArray( (mxArray*)  vi[0]);
    holder.ind.setMxArray( (mxArray*) vi[1]);
    dt = holder.dY.getDevice();

    if (holder.dY.getElemType()  != mxSINGLE_CLASS ||
        holder.ind.getElemType() != mxINT32_CLASS)
      throw mp3d_ex(kBpropUsage);

    holder.ct = maxpool3d::BPROP;
    opt_beg = 2;
  }
  else {
    throw mp3d_ex("Unrecognized arguments/way of calling. "
      "The output should be either [Y, ind] (fprop) or ind (bprop). ");
  }

  // if bprop: create dX if szX is provided args:(dzdY, ind, szX)
  if (holder.ct == maxpool3d::BPROP) {
    // A char third argument is the first option name, so szX was omitted.
    if (ni >= 3 && !mxIsChar(vi[2]) ) {
      // szX provided, check it. Report errors via mp3d_ex like every other
      // check in this function rather than calling mexErrMsgTxt directly.
      if (!mxIsDouble(vi[2]))
        throw mp3d_ex("The third argument szX must be a double array.\n");

      const mwSize nelem = mxGetNumberOfElements(vi[2]);
      if (nelem > 5 || nelem < 3)
        throw mp3d_ex("The third argument must be: 3 <= numel(szX) <= 5.\n");

      const double *ptr = static_cast<const double*>(mxGetData(vi[2]));

      // get szX; trailing dimensions that were not given default to 1
      mwSize szX[5];
      szX[3] = szX[4] = 1;
      for (mwSize i = 0; i < nelem; ++i) {
        // Converting a negative (or NaN) double to an unsigned size is
        // undefined behavior, so reject such values up front.
        if (!(ptr[i] >= 0))
          throw mp3d_ex("The third argument szX must contain non-negative sizes.\n");
        szX[i] = static_cast<mwSize>(ptr[i]);
      }

      // create the dX
      holder.dX.setMxArray( createVol5dZeros(szX, holder.dY.dt) );

      // reset the option beginning
      opt_beg = 3;
    }
    else // issue the warning
      mexWarnMsgTxt("For bprop(), the calling method with 2 args:\n"
        "[...] = mex_maxpool3d(dzdY, ind, ...)\n"
        "is deprecated, because this could cause ambiguity when inferring input X size.\n"
        "Use the new one to specify the size for input X (or dzdX) explicitly:\n"
        "[...] = mex_maxpool3d(dzdY, ind, szX,...)\n");
  }

  // set options
  set_options(holder, opt_beg, ni, vi);

  // check validity
  check_padpool(holder);

  // create the desired worker and set the parameters
#ifdef WITH_GPUARRAY
  if (dt == xpuMxArrayTW::GPU)
    return new maxpool3d_gpu(holder);
  else
    return new maxpool3d_cpu(holder);
#else
  return new maxpool3d_cpu(holder);
#endif // WITH_GPUARRAY
}
Beispiel #3
0
void factory_mp3d_homebrew::set_stride(maxpool3d &h, mxArray const *pa )
{
  // Fill h.stride (3 entries) from the option value `pa`; the error text
  // states that either 1 or 3 numbers are allowed.
  const bool stride_ok = setCArray<mwSize, 3>(pa, h.stride);
  if (!stride_ok) {
    throw mp3d_ex("The length of option stride must be 1 or 3.");
  }
}
Beispiel #4
0
void maxpool3d::check_dY_ind()
{
  // bprop requires dY and ind to report the same device (both gpuArray or both CPU).
  const bool device_mismatch = dY.getDevice() != ind.getDevice();
  if (!device_mismatch)
    return;

  throw mp3d_ex("In bprop(): dY and ind must be both gpuArray or CPU mxArray.\n");
}