// Option "pad": copy the user-supplied value(s) in pa into h.pad (6 entries).
// A false return from setCArray is reported as a wrong-length option, as the
// message states (numel must be 1 or 6).
void factory_mp3d_homebrew::set_pad (maxpool3d &h, mxArray const *pa)
{
  bool const accepted = setCArray<mwSize, 6>(pa, h.pad);
  if (accepted)
    return;

  throw mp3d_ex("The length of option pad must be 1 or 6.");
}
// Parse the mex arguments and create the maxpool3d worker.
//
// The number of outputs selects the calling mode:
//   no == 2 : fprop, [Y, ind] = mex_maxpool3d(X, opts...)
//   no == 1 : bprop, dzdX = mex_maxpool3d(dzdY, ind, [szX,] opts...)
// Any other output count is rejected.
//
// Params:
//   no - number of outputs requested by the caller
//   vo - output array (not used while parsing)
//   ni - number of inputs
//   vi - input arrays; they are wrapped, not copied or modified
// Returns:
//   A new maxpool3d_gpu (when built WITH_GPUARRAY and the data is on the GPU)
//   or a new maxpool3d_cpu, constructed from the parsed holder.
//   NOTE(review): presumably the caller takes ownership and deletes it; confirm.
// Throws:
//   mp3d_ex on any invalid argument.
maxpool3d* factory_mp3d_homebrew::parse_and_create (int no, mxArray *vo[], int ni, mxArray const *vi[])
{
  if (ni < 1)
    throw mp3d_ex("Too few input arguments.");

  // fprop or bprop?
  maxpool3d holder;
  int opt_beg = -1;  // index in vi where the name/value options start
  xpuMxArrayTW::DEV_TYPE dt;
  if (no == 2) { // fprop
    holder.X.setMxArray( (mxArray*) vi[0] ); // we won't change it
    dt = holder.X.getDevice();

    if ( ni < 1 || (holder.X.getElemType() != mxSINGLE_CLASS) )
      throw mp3d_ex("For fprop(), there should be at least one input, X, of SINGLE type,"
        "be all gpuArray or be all mxArray.\n");

    holder.ct = maxpool3d::FPROP;
    opt_beg = 1;
  }
  else if (no == 1) { // bprop
    char const *bprop_usage =
      "For bprop(): there should be at least 2 arguments, dzdY, ind.\n"
      "The dzdY must be SINGLE, the max index ind must be int32,"
      "they should be both gpuArray or be both mxArray.\n";

    // Check the count BEFORE touching vi[1]; with ni == 1 reading vi[1]
    // would go past the end of the input array.
    if (ni < 2)
      throw mp3d_ex(bprop_usage);

    holder.dY.setMxArray( (mxArray*) vi[0] );
    holder.ind.setMxArray( (mxArray*) vi[1] );
    dt = holder.dY.getDevice();

    if ( holder.dY.getElemType() != mxSINGLE_CLASS ||
         holder.ind.getElemType() != mxINT32_CLASS )
      throw mp3d_ex(bprop_usage);

    holder.ct = maxpool3d::BPROP;
    opt_beg = 2;
  }
  else {
    throw mp3d_ex("Unrecognized arguments/way of calling. "
      "The output should be either [Y, ind] (fprop) or ind (bprop). ");
  }

  // if bprop: create dX if szX is provided args:(dzdY, ind, szX)
  if (holder.ct == maxpool3d::BPROP) {
    // A char third argument is an option name, not szX.
    if (ni >= 3 && !mxIsChar(vi[2]) ) { // szX provided, check it
      // Report through mp3d_ex like every other argument error here, so C++
      // objects (e.g. holder) are unwound normally.
      if ( !mxIsDouble(vi[2]) )
        throw mp3d_ex("The third argument szX must be a double array.\n");

      mwSize const nelem = mxGetNumberOfElements(vi[2]);
      if (nelem > 5 || nelem < 3)
        throw mp3d_ex("The third argument must be: 3 <= numel(szX) <= 5.\n");

      // get szX; unspecified trailing dims (4th, 5th) default to 1
      double const *ptr = static_cast<double const*>( mxGetData(vi[2]) );
      mwSize szX[5] = {1, 1, 1, 1, 1};
      for (mwSize i = 0; i < nelem; ++i) {
        // A negative (or NaN) size would wrap to a huge unsigned value.
        if ( !(ptr[i] >= 0) )
          throw mp3d_ex("The elements of szX must be non-negative.\n");
        szX[i] = static_cast<mwSize>(ptr[i]);
      }

      // create the dX, on the same device type as dY
      holder.dX.setMxArray( createVol5dZeros(szX, holder.dY.dt) );

      // reset the option beginning
      opt_beg = 3;
    }
    else // issue the warning
      mexWarnMsgTxt("For bprop(), the calling method with 2 args:\n"
        "[...] = mex_maxpool3d(dzdY, ind, ...)\n"
        "is deprecated, because this could cause ambiguity when inferring input X size.\n"
        "Use the new one to specify the size for input X (or dzdX) explicitly:\n"
        "[...] = mex_maxpool3d(dzdY, ind, szX,...)\n");
  }

  // set options
  set_options(holder, opt_beg, ni, vi);

  // check validity
  check_padpool(holder);

  // create the desired worker and set the parameters
#ifdef WITH_GPUARRAY
  if (dt == xpuMxArrayTW::GPU)
    return new maxpool3d_gpu(holder);
  else
    return new maxpool3d_cpu(holder);
#else
  return new maxpool3d_cpu(holder);
#endif // WITH_GPUARRAY
}
// Option "stride": copy the user-supplied value(s) in pa into h.stride
// (3 entries). A false return from setCArray is reported as a wrong-length
// option, as the message states (numel must be 1 or 3).
void factory_mp3d_homebrew::set_stride (maxpool3d &h, mxArray const *pa)
{
  bool const accepted = setCArray<mwSize, 3>(pa, h.stride);
  if (accepted)
    return;

  throw mp3d_ex("The length of option stride must be 1 or 3.");
}
// Guard for bprop(): dY and ind must be on the same device.
// Throws mp3d_ex when their devices differ.
void maxpool3d::check_dY_ind ()
{
  bool const same_device = (dY.getDevice() == ind.getDevice());
  if (same_device)
    return;

  throw mp3d_ex("In bprop(): dY and ind must be both gpuArray or CPU mxArray.\n");
}