void mexFunctionTrain(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { if (nrhs != 4) { mexPrintf("Usage: [model] = SQBTreesTrain( feats, labels, maxIters, options )\n"); mexPrintf("\tfeats must be of type SINGLE\n"); mexPrintf("\tOptions contains the following fields:\n"); mexPrintf("\t\t.loss = 'exploss', 'squaredloss' or 'logloss'\n"); mexPrintf("\t\t.shrinkageFactor = between 0 and 1. For instance, 0.02\n"); mexPrintf("\t\t.subsamplingFactor = between 0 and 1. Initial suggestion: 0.5\n"); mexPrintf("\t\t.maxTreeDepth >= 0. UINT32. Initial suggestion: 2\n"); mexPrintf("\t\t.disableLineSearch, UINT32. If != 0 then line search is disabled and shrinkageFactor is used as the step length.\n"); mexPrintf("\t\t.mtry, UINT32. Num of variables to search at each split, randomly chosen (as in Random Forests). Optional, otherwise it equals number of columns in feats.\n"); mexPrintf("\t\t.randSeed. UINT32. If not provided, time is used to generate a seed.\n"); mexPrintf("\t\t.verboseOutput, UINT32. If != 0 then verbose output is enabled (default = false if not specified).\n"); mexErrMsgTxt("Incorrect input format\n"); } if (nlhs != 1) mexErrMsgTxt("One output arg expected"); #define mFeats (prhs[0]) #define mLabels (prhs[1]) #define mMaxIters (prhs[2]) #define mOptions (prhs[3]) MatlabInputMatrix<FeatsType> pFeats( mFeats, 0, 0, "feats" ); MatlabInputMatrix<WeightsType> pLabels( mLabels, pFeats.rows(), 1, "labels" ); MatlabInputMatrix<unsigned int> pMaxIters( mMaxIters, 1, 1, "maxiters" ); const unsigned maxIters = pMaxIters.data()[0]; TreeBoosterType TB; // we will use a random sampler SQB::TreeBoosterNaiveResampler< TreeBoosterType::ResamplerBaseObjectType::WeightsArrayType, TreeBoosterType::ResamplerBaseObjectType::LabelsArrayType > resampler; TB.setResamplerObject( &resampler ); // read options, a bit messy right now { mxArray *pLoss = mxGetField( mOptions, 0, "loss" ); if (pLoss == NULL) mexErrMsgTxt("options.loss not found"); mxArray *pSF = mxGetField( mOptions, 0, "shrinkageFactor" ); if (pSF == NULL) mexErrMsgTxt("options.shrinkageFactor not found"); mxArray *pSubSampFact = mxGetField( mOptions, 0, "subsamplingFactor" ); if (pSubSampFact == NULL) mexErrMsgTxt("options.subsamplingFactor not found"); mxArray *pMaxDepth = mxGetField( mOptions, 0, "maxTreeDepth" ); if (pMaxDepth == NULL) mexErrMsgTxt("options.maxTreeDepth not found"); mxArray *pRandSeed = mxGetField( mOptions, 0, "randSeed" ); const bool usingCustomRandSeed = (pRandSeed != NULL); mxArray *pDisableLineSearch = mxGetField( mOptions, 0, "disableLineSearch" ); const bool disableLineSearchSpecified = (pDisableLineSearch != NULL); mxArray *pVerboseOutput = mxGetField( mOptions, 0, "verboseOutput" ); const bool verboseOutputSpecified = (pVerboseOutput != NULL); mxArray *pMTry = mxGetField( mOptions, 0, "mtry" ); const bool mtrySpecified = (pMTry != NULL); char lossName[40]; if ( mxGetString(pLoss, lossName, sizeof(lossName)) != 0 ) mexErrMsgTxt("Error reading options.loss"); if ( mxGetClassID(pSF) != mxDOUBLE_CLASS ) mexErrMsgTxt("options.shrinkageFactor must be double"); if ( mxGetNumberOfElements(pSF) != 1 ) mexErrMsgTxt("options.shrinkageFactor must be a scalar"); { const double sf = ((double *)mxGetData(pSF))[0]; if ( sf <= 0 || sf > 1) mexErrMsgTxt("Shrinkage factor must be in (0,1]"); TB.setShrinkageFactor( sf ); } if (strcmp(lossName, "exploss") == 0) TB.setLoss( SQB::ExpLoss ); else if ( strcmp(lossName, "logloss") == 0 ) TB.setLoss( SQB::LogLoss ); else if ( strcmp(lossName, "squaredloss") == 0 ) TB.setLoss( SQB::SquaredLoss ); else mexErrMsgTxt("options.loss contains an invalid value"); if ( mxGetClassID(pSubSampFact) != mxDOUBLE_CLASS ) mexErrMsgTxt("options.subsamplingFactor must be double"); if ( mxGetNumberOfElements(pSubSampFact) != 1 ) mexErrMsgTxt("options.subsamplingFactor must be a scalar"); { const double ss = ((double *)mxGetData(pSubSampFact))[0]; if ( ss <= 0 || ss > 1 ) mexErrMsgTxt("Subsampling factor must be in (0,1]"); resampler.setResamplingFactor( ss ); } if ( mxGetClassID(pMaxDepth) != mxUINT32_CLASS ) mexErrMsgTxt("options.pMaxDepth must be UINT32"); if ( mxGetNumberOfElements(pMaxDepth) != 1 ) mexErrMsgTxt("options.maxDepth must be a scalar"); { const unsigned maxDepth = ((unsigned int *)mxGetData(pMaxDepth))[0]; if (maxDepth < 0) // not gonna happen if it is unsigned mexErrMsgTxt("Minimum maxDepth is 0"); TB.setMaxTreeDepth( maxDepth ); } if ( disableLineSearchSpecified ) { if ( mxGetClassID(pDisableLineSearch) != mxUINT32_CLASS ) mexErrMsgTxt("options.disableLineSearch must be UINT32"); if ( mxGetNumberOfElements(pDisableLineSearch) != 1 ) mexErrMsgTxt("options.disableLineSearch must be a scalar"); TB.setDisableLineSearch(((unsigned int *)mxGetData(pDisableLineSearch))[0] != 0); } else TB.setDisableLineSearch(false); if ( verboseOutputSpecified ) { if ( mxGetClassID(pVerboseOutput) != mxUINT32_CLASS ) mexErrMsgTxt("options.verboseOutput must be UINT32"); if ( mxGetNumberOfElements(pVerboseOutput) != 1 ) mexErrMsgTxt("options.verboseOutput must be a scalar"); TB.setVerboseOutput(((unsigned int *)mxGetData(pVerboseOutput))[0] != 0); } else TB.setVerboseOutput(false); if ( mtrySpecified ) { if ( mxGetClassID(pMTry) != mxUINT32_CLASS ) mexErrMsgTxt("options.mtry must be UINT32"); if ( mxGetNumberOfElements(pMTry) != 1 ) mexErrMsgTxt("options.mtry must be a scalar"); TB.setMTry(((unsigned int *)mxGetData(pMTry))[0]); } else TB.setDisableLineSearch(false); if (usingCustomRandSeed) { if ( mxGetClassID(pRandSeed) != mxUINT32_CLASS ) mexErrMsgTxt("options.randSeed must be UINT32"); if ( mxGetNumberOfElements(pRandSeed) != 1 ) mexErrMsgTxt("options.randSeed must be a scalar"); TB.setRandSeed( ((unsigned int *)mxGetData(pRandSeed))[0] ); } else { TB.setRandSeed( time(NULL) ); } } { // for now just copy the values gFeatArrayType feats = Eigen::Map< const gFeatArrayType >( pFeats.data(), pFeats.rows(), pFeats.cols() ); TreeBoosterType::ResponseArrayType labels = Eigen::Map< const TreeBoosterType::ResponseArrayType >( pLabels.data(), pLabels.rows() ); TB.printOptionsSummary(); TB.learn( TreeBoosterType::SampleListType(feats), TreeBoosterType::FeatureListType(feats), TreeBoosterType::FeatureValueObjectType(feats), TreeBoosterType::ClassifierResponseValueObjectType(labels), maxIters ); TB.printOptionsSummary(); } plhs[0] = TB.saveToMatlab(); #undef mFeats #undef mLabels #undef mMaxIters #undef mPredFeats }
void mexFunctionTrain(TreeBoosterType &TB/*int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]*/) { // if (nrhs != 4) // { // mexPrintf("Usage: [model] = SQBTreesTrain( feats, labels, maxIters, options )\n"); // mexPrintf("\tfeats must be of type SINGLE\n"); // mexPrintf("\tOptions contains the following fields:\n"); // mexPrintf("\t\t.loss = 'exploss', 'squaredloss' or 'logloss'\n"); // mexPrintf("\t\t.shrinkageFactor = between 0 and 1. For instance, 0.02\n"); // mexPrintf("\t\t.subsamplingFactor = between 0 and 1. Initial suggestion: 0.5\n"); // mexPrintf("\t\t.maxTreeDepth >= 0. UINT32. Initial suggestion: 2\n"); // mexPrintf("\t\t.disableLineSearch, UINT32. If != 0 then line search is disabled and shrinkageFactor is used as the step length.\n"); // mexPrintf("\t\t.mtry, UINT32. Num of variables to search at each split, randomly chosen (as in Random Forests). Optional, otherwise it equals number of columns in feats.\n"); // mexPrintf("\t\t.randSeed. UINT32. If not provided, time is used to generate a seed.\n"); // mexPrintf("\t\t.verboseOutput, UINT32. If != 0 then verbose output is enabled (default = false if not specified).\n"); // mexErrMsgTxt("Incorrect input format\n"); // } // if (nlhs != 1) // mexErrMsgTxt("One output arg expected"); // #define mFeats (prhs[0]) // #define mLabels (prhs[1]) // #define mMaxIters (prhs[2]) // #define mOptions (prhs[3]) // MatlabInputMatrix<FeatsType> pFeats( mFeats, 0, 0, "feats" ); // MatlabInputMatrix<WeightsType> pLabels( mLabels, pFeats.rows(), 1, "labels" ); // MatlabInputMatrix<unsigned int> pMaxIters( mMaxIters, 1, 1, "maxiters" ); // const unsigned maxIters = pMaxIters.data()[0]; // TreeBoosterType TB; // we will use a random sampler SQB::TreeBoosterNaiveResampler< TreeBoosterType::ResamplerBaseObjectType::WeightsArrayType, TreeBoosterType::ResamplerBaseObjectType::LabelsArrayType > resampler; TB.setResamplerObject( &resampler ); // // read options, a bit messy right now // { // mxArray *pLoss = mxGetField( mOptions, 0, "loss" ); // if (pLoss == NULL) mexErrMsgTxt("options.loss not found"); // mxArray *pSF = mxGetField( mOptions, 0, "shrinkageFactor" ); // if (pSF == NULL) mexErrMsgTxt("options.shrinkageFactor not found"); // mxArray *pSubSampFact = mxGetField( mOptions, 0, "subsamplingFactor" ); // if (pSubSampFact == NULL) mexErrMsgTxt("options.subsamplingFactor not found"); // mxArray *pMaxDepth = mxGetField( mOptions, 0, "maxTreeDepth" ); // if (pMaxDepth == NULL) mexErrMsgTxt("options.maxTreeDepth not found"); // mxArray *pRandSeed = mxGetField( mOptions, 0, "randSeed" ); // const bool usingCustomRandSeed = (pRandSeed != NULL); // mxArray *pDisableLineSearch = mxGetField( mOptions, 0, "disableLineSearch" ); // const bool disableLineSearchSpecified = (pDisableLineSearch != NULL); // mxArray *pVerboseOutput = mxGetField( mOptions, 0, "verboseOutput" ); // const bool verboseOutputSpecified = (pVerboseOutput != NULL); // mxArray *pMTry = mxGetField( mOptions, 0, "mtry" ); // const bool mtrySpecified = (pMTry != NULL); // char lossName[40]; // if ( mxGetString(pLoss, lossName, sizeof(lossName)) != 0 ) // mexErrMsgTxt("Error reading options.loss"); // if ( mxGetClassID(pSF) != mxDOUBLE_CLASS ) // mexErrMsgTxt("options.shrinkageFactor must be double"); // if ( mxGetNumberOfElements(pSF) != 1 ) // mexErrMsgTxt("options.shrinkageFactor must be a scalar"); // { // const double sf = ((double *)mxGetData(pSF))[0]; // if ( sf <= 0 || sf > 1) mexErrMsgTxt("Shrinkage factor must be in (0,1]"); // TB.setShrinkageFactor( sf ); // } // if (strcmp(lossName, "exploss") == 0) // TB.setLoss( SQB::ExpLoss ); // else if ( strcmp(lossName, "logloss") == 0 ) // TB.setLoss( SQB::LogLoss ); // else if ( strcmp(lossName, "squaredloss") == 0 ) // TB.setLoss( SQB::SquaredLoss ); // else // mexErrMsgTxt("options.loss contains an invalid value"); // if ( mxGetClassID(pSubSampFact) != mxDOUBLE_CLASS ) // mexErrMsgTxt("options.subsamplingFactor must be double"); // if ( mxGetNumberOfElements(pSubSampFact) != 1 ) // mexErrMsgTxt("options.subsamplingFactor must be a scalar"); // { // const double ss = ((double *)mxGetData(pSubSampFact))[0]; // if ( ss <= 0 || ss > 1 ) mexErrMsgTxt("Subsampling factor must be in (0,1]"); // resampler.setResamplingFactor( ss ); // } // if ( mxGetClassID(pMaxDepth) != mxUINT32_CLASS ) // mexErrMsgTxt("options.pMaxDepth must be UINT32"); // if ( mxGetNumberOfElements(pMaxDepth) != 1 ) // mexErrMsgTxt("options.maxDepth must be a scalar"); // { // const unsigned maxDepth = ((unsigned int *)mxGetData(pMaxDepth))[0]; // if (maxDepth < 0) // not gonna happen if it is unsigned // mexErrMsgTxt("Minimum maxDepth is 0"); // TB.setMaxTreeDepth( maxDepth ); // } // if ( disableLineSearchSpecified ) // { // if ( mxGetClassID(pDisableLineSearch) != mxUINT32_CLASS ) // mexErrMsgTxt("options.disableLineSearch must be UINT32"); // if ( mxGetNumberOfElements(pDisableLineSearch) != 1 ) // mexErrMsgTxt("options.disableLineSearch must be a scalar"); // TB.setDisableLineSearch(((unsigned int *)mxGetData(pDisableLineSearch))[0] != 0); // } // else // TB.setDisableLineSearch(false); // if ( verboseOutputSpecified ) // { // if ( mxGetClassID(pVerboseOutput) != mxUINT32_CLASS ) // mexErrMsgTxt("options.verboseOutput must be UINT32"); // if ( mxGetNumberOfElements(pVerboseOutput) != 1 ) // mexErrMsgTxt("options.verboseOutput must be a scalar"); // TB.setVerboseOutput(((unsigned int *)mxGetData(pVerboseOutput))[0] != 0); // } // else // TB.setVerboseOutput(false); // if ( mtrySpecified ) // { // if ( mxGetClassID(pMTry) != mxUINT32_CLASS ) // mexErrMsgTxt("options.mtry must be UINT32"); // if ( mxGetNumberOfElements(pMTry) != 1 ) // mexErrMsgTxt("options.mtry must be a scalar"); // TB.setMTry(((unsigned int *)mxGetData(pMTry))[0]); // } // else // TB.setDisableLineSearch(false); // if (usingCustomRandSeed) // { // if ( mxGetClassID(pRandSeed) != mxUINT32_CLASS ) // mexErrMsgTxt("options.randSeed must be UINT32"); // if ( mxGetNumberOfElements(pRandSeed) != 1 ) // mexErrMsgTxt("options.randSeed must be a scalar"); // TB.setRandSeed( ((unsigned int *)mxGetData(pRandSeed))[0] ); // } else // { // TB.setRandSeed( time(NULL) ); // } // } { //pFeats.data() returns const FeatsType* //pFeats.rows() and pFeats.cols() return unsigned int //pLabels.data() returns SQB::TreeBoosterWeightsType (i.e. double) //pLabels.rows() returns unsigned int // // for now just copy the values // gFeatArrayType feats = Eigen::Map< const gFeatArrayType >( pFeats.data(), pFeats.rows(), pFeats.cols() ); // TreeBoosterType::ResponseArrayType labels = Eigen::Map< const TreeBoosterType::ResponseArrayType >( pLabels.data(), pLabels.rows() ); // TB.printOptionsSummary(); // TB.learn( TreeBoosterType::SampleListType(feats), // TreeBoosterType::FeatureListType(feats), // TreeBoosterType::FeatureValueObjectType(feats), // TreeBoosterType::ClassifierResponseValueObjectType(labels), // maxIters ); // TB.printOptionsSummary(); const unsigned maxIters = 200; // for now just copy the values FeatsType featuresArray[] = {3, 1, 2, 6, 3, 4, -1, -2, -4, -2, -3, -1 }; unsigned int featuresColsNo = 3; unsigned int featuresRowsNo = 4; WeightsType labelsArray[] = {1, 1, -1, -1}; unsigned int labelsRowsNo = featuresRowsNo; gFeatArrayType feats = Eigen::Map< const gFeatArrayType >( featuresArray, featuresRowsNo, featuresColsNo ); TreeBoosterType::ResponseArrayType labels = Eigen::Map< const TreeBoosterType::ResponseArrayType >( labelsArray, labelsRowsNo ); TB.printOptionsSummary(); TB.learn( TreeBoosterType::SampleListType(feats), TreeBoosterType::FeatureListType(feats), TreeBoosterType::FeatureValueObjectType(feats), TreeBoosterType::ClassifierResponseValueObjectType(labels), maxIters ); std::cout << "Hu!" << std::endl << std::flush; TB.printOptionsSummary(); } // libconfig::Config cfg; // libconfig::Setting &root = cfg.getRoot(); // // Add some settings to the configuration. // libconfig::Setting &address = root.add("regressor", libconfig::Setting::TypeList); libconfig::Config cfg; libconfig::Setting &root = cfg.getRoot(); // // Add some settings to the configuration. // libconfig::Setting &address = root.add("address", libconfig::Setting::TypeGroup); // address.add("street", libconfig::Setting::TypeString) = "1 Woz Way"; // address.add("city", libconfig::Setting::TypeString) = "San Jose"; // address.add("state", libconfig::Setting::TypeString) = "CA"; // address.add("zip", libconfig::Setting::TypeInt) = 95110; // libconfig::Setting &array = root.add("array", libconfig::Setting::TypeArray); // for(int i = 0; i < 10; ++i) // array.add(libconfig::Setting::TypeInt) = 10 * i; // root.add(libconfig::Setting::TypeList); libconfig::Setting ®ressor = root.add("regressor", libconfig::Setting::TypeList); TB.saveToLibconfig(regressor); static const char *output_file = "/cvlabdata1/home/pglowack/Work/Vaa3D-BuiltWithDefaultScripts/vaa3d_tools/" "bigneuron_ported/AmosSironi_PrzemyslawGlowacki/SQBTree_plugin/aaaaaaaa.cfg"; // Write out the new configuration. try { cfg.writeFile(output_file); std::cerr << "New configuration successfully written to: " << output_file << std::endl; } catch(const libconfig::FileIOException &fioex) { std::cerr << "I/O error while writing file: " << output_file << std::endl; //return(EXIT_FAILURE); } // plhs[0] = TB.saveToMatlab(); //#undef mFeats //#undef mLabels //#undef mMaxIters //#undef mPredFeats }