void mexFunctionTrain(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { if (nrhs != 4) { mexPrintf("Usage: [model] = SQBTreesTrain( feats, labels, maxIters, options )\n"); mexPrintf("\tfeats must be of type SINGLE\n"); mexPrintf("\tOptions contains the following fields:\n"); mexPrintf("\t\t.loss = 'exploss', 'squaredloss' or 'logloss'\n"); mexPrintf("\t\t.shrinkageFactor = between 0 and 1. For instance, 0.02\n"); mexPrintf("\t\t.subsamplingFactor = between 0 and 1. Initial suggestion: 0.5\n"); mexPrintf("\t\t.maxTreeDepth >= 0. UINT32. Initial suggestion: 2\n"); mexPrintf("\t\t.disableLineSearch, UINT32. If != 0 then line search is disabled and shrinkageFactor is used as the step length.\n"); mexPrintf("\t\t.mtry, UINT32. Num of variables to search at each split, randomly chosen (as in Random Forests). Optional, otherwise it equals number of columns in feats.\n"); mexPrintf("\t\t.randSeed. UINT32. If not provided, time is used to generate a seed.\n"); mexPrintf("\t\t.verboseOutput, UINT32. If != 0 then verbose output is enabled (default = false if not specified).\n"); mexErrMsgTxt("Incorrect input format\n"); } if (nlhs != 1) mexErrMsgTxt("One output arg expected"); #define mFeats (prhs[0]) #define mLabels (prhs[1]) #define mMaxIters (prhs[2]) #define mOptions (prhs[3]) MatlabInputMatrix<FeatsType> pFeats( mFeats, 0, 0, "feats" ); MatlabInputMatrix<WeightsType> pLabels( mLabels, pFeats.rows(), 1, "labels" ); MatlabInputMatrix<unsigned int> pMaxIters( mMaxIters, 1, 1, "maxiters" ); const unsigned maxIters = pMaxIters.data()[0]; TreeBoosterType TB; // we will use a random sampler SQB::TreeBoosterNaiveResampler< TreeBoosterType::ResamplerBaseObjectType::WeightsArrayType, TreeBoosterType::ResamplerBaseObjectType::LabelsArrayType > resampler; TB.setResamplerObject( &resampler ); // read options, a bit messy right now { mxArray *pLoss = mxGetField( mOptions, 0, "loss" ); if (pLoss == NULL) mexErrMsgTxt("options.loss not found"); mxArray *pSF = mxGetField( mOptions, 0, "shrinkageFactor" ); if (pSF == NULL) mexErrMsgTxt("options.shrinkageFactor not found"); mxArray *pSubSampFact = mxGetField( mOptions, 0, "subsamplingFactor" ); if (pSubSampFact == NULL) mexErrMsgTxt("options.subsamplingFactor not found"); mxArray *pMaxDepth = mxGetField( mOptions, 0, "maxTreeDepth" ); if (pMaxDepth == NULL) mexErrMsgTxt("options.maxTreeDepth not found"); mxArray *pRandSeed = mxGetField( mOptions, 0, "randSeed" ); const bool usingCustomRandSeed = (pRandSeed != NULL); mxArray *pDisableLineSearch = mxGetField( mOptions, 0, "disableLineSearch" ); const bool disableLineSearchSpecified = (pDisableLineSearch != NULL); mxArray *pVerboseOutput = mxGetField( mOptions, 0, "verboseOutput" ); const bool verboseOutputSpecified = (pVerboseOutput != NULL); mxArray *pMTry = mxGetField( mOptions, 0, "mtry" ); const bool mtrySpecified = (pMTry != NULL); char lossName[40]; if ( mxGetString(pLoss, lossName, sizeof(lossName)) != 0 ) mexErrMsgTxt("Error reading options.loss"); if ( mxGetClassID(pSF) != mxDOUBLE_CLASS ) mexErrMsgTxt("options.shrinkageFactor must be double"); if ( mxGetNumberOfElements(pSF) != 1 ) mexErrMsgTxt("options.shrinkageFactor must be a scalar"); { const double sf = ((double *)mxGetData(pSF))[0]; if ( sf <= 0 || sf > 1) mexErrMsgTxt("Shrinkage factor must be in (0,1]"); TB.setShrinkageFactor( sf ); } if (strcmp(lossName, "exploss") == 0) TB.setLoss( SQB::ExpLoss ); else if ( strcmp(lossName, "logloss") == 0 ) TB.setLoss( SQB::LogLoss ); else if ( strcmp(lossName, "squaredloss") == 0 ) TB.setLoss( SQB::SquaredLoss ); else mexErrMsgTxt("options.loss contains an invalid value"); if ( mxGetClassID(pSubSampFact) != mxDOUBLE_CLASS ) mexErrMsgTxt("options.subsamplingFactor must be double"); if ( mxGetNumberOfElements(pSubSampFact) != 1 ) mexErrMsgTxt("options.subsamplingFactor must be a scalar"); { const double ss = ((double *)mxGetData(pSubSampFact))[0]; if ( ss <= 0 || ss > 1 ) mexErrMsgTxt("Subsampling factor must be in (0,1]"); resampler.setResamplingFactor( ss ); } if ( mxGetClassID(pMaxDepth) != mxUINT32_CLASS ) mexErrMsgTxt("options.pMaxDepth must be UINT32"); if ( mxGetNumberOfElements(pMaxDepth) != 1 ) mexErrMsgTxt("options.maxDepth must be a scalar"); { const unsigned maxDepth = ((unsigned int *)mxGetData(pMaxDepth))[0]; if (maxDepth < 0) // not gonna happen if it is unsigned mexErrMsgTxt("Minimum maxDepth is 0"); TB.setMaxTreeDepth( maxDepth ); } if ( disableLineSearchSpecified ) { if ( mxGetClassID(pDisableLineSearch) != mxUINT32_CLASS ) mexErrMsgTxt("options.disableLineSearch must be UINT32"); if ( mxGetNumberOfElements(pDisableLineSearch) != 1 ) mexErrMsgTxt("options.disableLineSearch must be a scalar"); TB.setDisableLineSearch(((unsigned int *)mxGetData(pDisableLineSearch))[0] != 0); } else TB.setDisableLineSearch(false); if ( verboseOutputSpecified ) { if ( mxGetClassID(pVerboseOutput) != mxUINT32_CLASS ) mexErrMsgTxt("options.verboseOutput must be UINT32"); if ( mxGetNumberOfElements(pVerboseOutput) != 1 ) mexErrMsgTxt("options.verboseOutput must be a scalar"); TB.setVerboseOutput(((unsigned int *)mxGetData(pVerboseOutput))[0] != 0); } else TB.setVerboseOutput(false); if ( mtrySpecified ) { if ( mxGetClassID(pMTry) != mxUINT32_CLASS ) mexErrMsgTxt("options.mtry must be UINT32"); if ( mxGetNumberOfElements(pMTry) != 1 ) mexErrMsgTxt("options.mtry must be a scalar"); TB.setMTry(((unsigned int *)mxGetData(pMTry))[0]); } else TB.setDisableLineSearch(false); if (usingCustomRandSeed) { if ( mxGetClassID(pRandSeed) != mxUINT32_CLASS ) mexErrMsgTxt("options.randSeed must be UINT32"); if ( mxGetNumberOfElements(pRandSeed) != 1 ) mexErrMsgTxt("options.randSeed must be a scalar"); TB.setRandSeed( ((unsigned int *)mxGetData(pRandSeed))[0] ); } else { TB.setRandSeed( time(NULL) ); } } { // for now just copy the values gFeatArrayType feats = Eigen::Map< const gFeatArrayType >( pFeats.data(), pFeats.rows(), pFeats.cols() ); TreeBoosterType::ResponseArrayType labels = Eigen::Map< const TreeBoosterType::ResponseArrayType >( pLabels.data(), pLabels.rows() ); TB.printOptionsSummary(); TB.learn( TreeBoosterType::SampleListType(feats), TreeBoosterType::FeatureListType(feats), TreeBoosterType::FeatureValueObjectType(feats), TreeBoosterType::ClassifierResponseValueObjectType(labels), maxIters ); TB.printOptionsSummary(); } plhs[0] = TB.saveToMatlab(); #undef mFeats #undef mLabels #undef mMaxIters #undef mPredFeats }