예제 #1
0
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    if (nrhs != 4)
    {
        mexPrintf("Usage: alpha = LineSearch(prevScores, newScores, labels, lossType)\n");
        mexErrMsgTxt("Incorrect input format\n");
    }

#define mPrevScores (prhs[0])
#define mNewScores (prhs[1])
#define mLabels (prhs[2])
#define mLossType (prhs[3])

    if (nlhs != 1)
        mexErrMsgTxt("One output arg expected");

    char lossName[40];
    if ( mxGetString(mLossType, lossName, sizeof(lossName)) != 0 )
        mexErrMsgTxt("Error reading options.loss");

    SQB::LossType sqbLoss = SQB::ExpLoss;

    if (strcmp(lossName, "exploss") == 0)
        sqbLoss = SQB::ExpLoss;
    else if ( strcmp(lossName, "logloss") == 0 )
        sqbLoss = SQB::LogLoss;
    else if ( strcmp(lossName, "squaredloss") == 0 )
        sqbLoss = SQB::SquaredLoss;
    else
        mexErrMsgTxt("options.loss contains an invalid value");


    MatlabInputMatrix<WeightsType> pPrev( mPrevScores, 0, 1, "prevScores" );
    MatlabInputMatrix<WeightsType> pNew( mNewScores, pPrev.rows(), 1, "newScores" );
    MatlabInputMatrix<WeightsType> pLabels( mLabels, pPrev.rows(), 1, "labels" );

    // create mappings
    ArrayMapType prevMap( pPrev.data(), pPrev.rows(), pPrev.cols() );
    ArrayMapType newMap( pNew.data(), pNew.rows(), pNew.cols() );
    ArrayMapType labelsMap( pLabels.data(), pLabels.rows(), pLabels.cols() );



    SQB::LineSearch< ArrayType, ArrayMapType >  LS( prevMap, newMap, labelsMap, sqbLoss );

    WeightsType alpha = LS.run();

    MatlabOutputMatrix<WeightsType>   outMatrix( &plhs[0], 1, 1 );
    outMatrix.data()[0] = alpha;
}
예제 #2
0
void mexFunctionTrain(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    if (nrhs != 4)
    {
        mexPrintf("Usage: [model] = SQBTreesTrain( feats, labels, maxIters, options )\n");
        mexPrintf("\tfeats must be of type SINGLE\n");
        mexPrintf("\tOptions contains the following fields:\n");
        mexPrintf("\t\t.loss = 'exploss', 'squaredloss' or 'logloss'\n");
        mexPrintf("\t\t.shrinkageFactor = between 0 and 1. For instance, 0.02\n");
        mexPrintf("\t\t.subsamplingFactor = between 0 and 1. Initial suggestion: 0.5\n");
        mexPrintf("\t\t.maxTreeDepth >= 0. UINT32. Initial suggestion: 2\n");
        mexPrintf("\t\t.disableLineSearch, UINT32. If != 0 then line search is disabled and shrinkageFactor is used as the step length.\n");
        mexPrintf("\t\t.mtry, UINT32. Num of variables to search at each split, randomly chosen (as in Random Forests). Optional, otherwise it equals number of columns in feats.\n");
        mexPrintf("\t\t.randSeed. UINT32. If not provided, time is used to generate a seed.\n");
        mexPrintf("\t\t.verboseOutput, UINT32. If != 0 then verbose output is enabled (default = false if not specified).\n");
        mexErrMsgTxt("Incorrect input format\n");
    }

    if (nlhs != 1)
        mexErrMsgTxt("One output arg expected");

#define mFeats (prhs[0])
#define mLabels (prhs[1])
#define mMaxIters (prhs[2])
#define mOptions (prhs[3])

    MatlabInputMatrix<FeatsType> pFeats( mFeats, 0, 0, "feats" );
    MatlabInputMatrix<WeightsType> pLabels( mLabels, pFeats.rows(), 1, "labels" );
    MatlabInputMatrix<unsigned int> pMaxIters( mMaxIters, 1, 1, "maxiters" );

    const unsigned maxIters = pMaxIters.data()[0];

    TreeBoosterType TB;

    // we will use a random sampler
    SQB::TreeBoosterNaiveResampler< TreeBoosterType::ResamplerBaseObjectType::WeightsArrayType,
        TreeBoosterType::ResamplerBaseObjectType::LabelsArrayType >  resampler;

    TB.setResamplerObject( &resampler );

    // read options, a bit messy right now
    {
        mxArray *pLoss = mxGetField( mOptions, 0, "loss" );
        if (pLoss == NULL)  mexErrMsgTxt("options.loss not found");

        mxArray *pSF = mxGetField( mOptions, 0, "shrinkageFactor" );
        if (pSF == NULL)  mexErrMsgTxt("options.shrinkageFactor not found");

        mxArray *pSubSampFact = mxGetField( mOptions, 0, "subsamplingFactor" );
        if (pSubSampFact == NULL)  mexErrMsgTxt("options.subsamplingFactor not found");

        mxArray *pMaxDepth = mxGetField( mOptions, 0, "maxTreeDepth" );
        if (pMaxDepth == NULL)  mexErrMsgTxt("options.maxTreeDepth not found");

        mxArray *pRandSeed = mxGetField( mOptions, 0, "randSeed" );
        const bool usingCustomRandSeed = (pRandSeed != NULL);

        mxArray *pDisableLineSearch = mxGetField( mOptions, 0, "disableLineSearch" );
        const bool disableLineSearchSpecified = (pDisableLineSearch != NULL);

        mxArray *pVerboseOutput = mxGetField( mOptions, 0, "verboseOutput" );
        const bool verboseOutputSpecified = (pVerboseOutput != NULL);

        mxArray *pMTry = mxGetField( mOptions, 0, "mtry" );
        const bool mtrySpecified = (pMTry != NULL);

        char lossName[40];
        if ( mxGetString(pLoss, lossName, sizeof(lossName)) != 0 )
            mexErrMsgTxt("Error reading options.loss");

        if ( mxGetClassID(pSF) != mxDOUBLE_CLASS )
            mexErrMsgTxt("options.shrinkageFactor must be double");
        if ( mxGetNumberOfElements(pSF) != 1 )
            mexErrMsgTxt("options.shrinkageFactor must be a scalar");

        {
            const double sf = ((double *)mxGetData(pSF))[0];
            if ( sf <= 0 || sf > 1) mexErrMsgTxt("Shrinkage factor must be in (0,1]");
            TB.setShrinkageFactor( sf );
        }

        if (strcmp(lossName, "exploss") == 0)
            TB.setLoss( SQB::ExpLoss );
        else if ( strcmp(lossName, "logloss") == 0 )
            TB.setLoss( SQB::LogLoss );
        else if ( strcmp(lossName, "squaredloss") == 0 )
            TB.setLoss( SQB::SquaredLoss );
        else
            mexErrMsgTxt("options.loss contains an invalid value");

        if ( mxGetClassID(pSubSampFact) != mxDOUBLE_CLASS )
            mexErrMsgTxt("options.subsamplingFactor must be double");
        if ( mxGetNumberOfElements(pSubSampFact) != 1 )
            mexErrMsgTxt("options.subsamplingFactor must be a scalar");

        {
            const double ss = ((double *)mxGetData(pSubSampFact))[0];
            if ( ss <= 0 || ss > 1 )    mexErrMsgTxt("Subsampling factor must be in (0,1]");
            resampler.setResamplingFactor( ss );
        }

        if ( mxGetClassID(pMaxDepth) != mxUINT32_CLASS )
            mexErrMsgTxt("options.pMaxDepth must be UINT32");
        if ( mxGetNumberOfElements(pMaxDepth) != 1 )
            mexErrMsgTxt("options.maxDepth must be a scalar");

        {
            const unsigned maxDepth = ((unsigned int *)mxGetData(pMaxDepth))[0];
            if (maxDepth < 0)   // not gonna happen if it is unsigned
                mexErrMsgTxt("Minimum maxDepth is 0");
            TB.setMaxTreeDepth( maxDepth );
        }

        if ( disableLineSearchSpecified )
        {
            if ( mxGetClassID(pDisableLineSearch) != mxUINT32_CLASS )
                mexErrMsgTxt("options.disableLineSearch must be UINT32");

            if ( mxGetNumberOfElements(pDisableLineSearch) != 1 )
                mexErrMsgTxt("options.disableLineSearch must be a scalar");

            TB.setDisableLineSearch(((unsigned int *)mxGetData(pDisableLineSearch))[0] != 0);
        }
        else
            TB.setDisableLineSearch(false);

        if ( verboseOutputSpecified )
        {
            if ( mxGetClassID(pVerboseOutput) != mxUINT32_CLASS )
                mexErrMsgTxt("options.verboseOutput must be UINT32");

            if ( mxGetNumberOfElements(pVerboseOutput) != 1 )
                mexErrMsgTxt("options.verboseOutput must be a scalar");

            TB.setVerboseOutput(((unsigned int *)mxGetData(pVerboseOutput))[0] != 0);
        }
        else
            TB.setVerboseOutput(false);

        if ( mtrySpecified )
        {
            if ( mxGetClassID(pMTry) != mxUINT32_CLASS )
                mexErrMsgTxt("options.mtry must be UINT32");

            if ( mxGetNumberOfElements(pMTry) != 1 )
                mexErrMsgTxt("options.mtry must be a scalar");

            TB.setMTry(((unsigned int *)mxGetData(pMTry))[0]);
        }
        else
            TB.setDisableLineSearch(false);

        if (usingCustomRandSeed)
        {
            if ( mxGetClassID(pRandSeed) != mxUINT32_CLASS )
                mexErrMsgTxt("options.randSeed must be UINT32");

            if ( mxGetNumberOfElements(pRandSeed) != 1 )
                mexErrMsgTxt("options.randSeed must be a scalar");

            TB.setRandSeed( ((unsigned int *)mxGetData(pRandSeed))[0] );
        } else
        {
            TB.setRandSeed( time(NULL) );
        }
    }

    {
        // for now just copy the values
        gFeatArrayType feats = Eigen::Map< const gFeatArrayType >( pFeats.data(), pFeats.rows(), pFeats.cols() );
        TreeBoosterType::ResponseArrayType labels = Eigen::Map< const TreeBoosterType::ResponseArrayType >( pLabels.data(), pLabels.rows() );

        TB.printOptionsSummary();
        TB.learn( TreeBoosterType::SampleListType(feats),
                  TreeBoosterType::FeatureListType(feats),
                  TreeBoosterType::FeatureValueObjectType(feats),
                  TreeBoosterType::ClassifierResponseValueObjectType(labels),
                  maxIters );
        TB.printOptionsSummary();
    }

    plhs[0] = TB.saveToMatlab();

#undef mFeats
#undef mLabels
#undef mMaxIters
#undef mPredFeats
}
예제 #3
0
// Labels the blobs of an AOI's selected pixels into a new "Blobs" raster
// element (INT2UBYTES, parented to the AOI). It then creates a pseudocolor
// display layer and records the blob count in the element's metadata.
//
// pInArgList  - required. Reads the Progress arg, the View arg and "AOI".
//               If "AOI" is not given, the AOI is taken from the view (see
//               below).
// pOutArgList - optional. When non-NULL it receives "Blobs" (the new element)
//               and "Number of Blobs".
// Returns true on success, false on any reported error.
bool ConnectedComponents::execute(PlugInArgList* pInArgList, PlugInArgList* pOutArgList)
{
   if (pInArgList == NULL)
   {
      return false;
   }
   mProgress = ProgressTracker(pInArgList->getPlugInArgValue<Progress>(Executable::ProgressArg()),
      "Labeling connected components", "app", "{aa2169d0-9c0a-4d41-9f1d-9a9e83ecf32b}");
   mpView = pInArgList->getPlugInArgValue<SpatialDataView>(Executable::ViewArg());
   AoiElement* pAoi = pInArgList->getPlugInArgValue<AoiElement>("AOI");
   // AOI resolution order: explicit "AOI" arg, then the view's active layer,
   // then (only if there is no active layer) the first AOI layer in the view.
   // If the active layer exists but is not an AOI layer, the dynamic_cast
   // yields NULL and there is no further fallback.
   if (pAoi == NULL && mpView != NULL)
   {
      Layer* pLayer = mpView->getActiveLayer();
      if (pLayer == NULL)
      {
         std::vector<Layer*> layers;
         mpView->getLayerList()->getLayers(AOI_LAYER, layers);
         if (!layers.empty())
         {
            pLayer = layers.front();
         }
      }
      pAoi = pLayer == NULL ? NULL : dynamic_cast<AoiElement*>(pLayer->getDataElement());
   }
   // Local variable despite the "mp" member-style prefix.
   const BitMask* mpBitmask = (pAoi == NULL) ? NULL : pAoi->getSelectedPoints();
   if (mpBitmask == NULL)
   {
      mProgress.report("Must specify an AOI.", 0, ERRORS, true);
      return false;
   }
   // An outside-selected mask has no finite bounding box to process.
   if (mpBitmask->isOutsideSelected())
   {
      mProgress.report("Infinite AOIs can not be processed.", 0, ERRORS, true);
      return false;
   }

   // Get the extents and create the output element
   int x1 = 0;
   int x2 = 0;
   int y1 = 0;
   int y2 = 0;
   mpBitmask->getMinimalBoundingBox(x1, y1, x2, y2);
   // Normalize so that (x1, y1) is the minimum corner.
   if (x1 > x2)
   {
      std::swap(x1, x2);
   }
   if (y1 > y2)
   {
      std::swap(y1, y2);
   }
   // Clamp negative coordinates to 0; pixels left of / above the origin are
   // dropped from processing.
   if (x1 < 0 || y1 < 0)
   {
      mProgress.report("Negative pixel locations are not supported and will be ignored.", 1, WARNING, true);
      x1 = std::max(x1, 0);
      y1 = std::max(y1, 0);
      x2 = std::max(x2, 0);
      y2 = std::max(y2, 0);
   }
   // Include a 1 pixel border so we include the edge pixels
   // (x1/y1 may become -1 here, so getPixel is queried at -1 below —
   // presumably it returns false outside the mask; TODO confirm).
   x1--;
   x2++;
   y1--;
   y2++;
   unsigned int width = x2 - x1 + 1;
   unsigned int height = y2 - y1 + 1;

   // Offsets map label-image coordinates (x, y) back to mask coordinates.
   mXOffset = x1;
   mYOffset = y1;

   // Replace any "Blobs" element left under this AOI by a previous run.
   mpLabels = static_cast<RasterElement*>(
      Service<ModelServices>()->getElement("Blobs", TypeConverter::toString<RasterElement>(), pAoi));
   if (mpLabels != NULL)
   {
      if (!isBatch())
      {
         Service<DesktopServices>()->showSuppressibleMsgDlg(
            getName(), "The \"Blobs\" element exists and will be deleted.", MESSAGE_INFO,
            "ConnectedComponents::DeleteExisting");
      }
      Service<ModelServices>()->destroyElement(mpLabels);
      mpLabels = NULL;
   }
   mpLabels = RasterUtilities::createRasterElement("Blobs", height, width, INT2UBYTES, true, pAoi);
   if (mpLabels == NULL)
   {
      mProgress.report("Unable to create label element.", 0, ERRORS, true);
      return false;
   }
   // pLabels owns the new element until pLabels.release() at the end, so every
   // early return below (including VERIFY and the catch) destroys it.
   // NOTE(review): the member mpLabels is not reset on those paths and is left
   // pointing at the destroyed element.
   ModelResource<RasterElement> pLabels(mpLabels);

   try
   {
      // Binary 8-bit image of the (bordered) bounding box: 255 where the mask
      // is selected, 0 elsewhere.
      cv::Mat data(height, width, CV_8UC1, cv::Scalar(0));
      for (unsigned int y = 0; y < height; ++y)
      {
         mProgress.report("Reading AOI data", 10 * y / height, NORMAL);
         for (unsigned int x = 0; x < width; ++x)
         {
            if (mpBitmask->getPixel(x + mXOffset, y + mYOffset))
            {
               data.at<unsigned char>(y, x) = 255;
            }
         }
      }
      mProgress.report("Finding contours", 15, NORMAL);
      std::vector<std::vector<cv::Point> > contours;
      std::vector<cv::Vec4i> hierarchy;
      // CV_RETR_TREE keeps the full nesting hierarchy (holes inside blobs, etc.).
      cv::findContours(data, contours, hierarchy, CV_RETR_TREE, CV_CHAIN_APPROX_SIMPLE);
      // 16-bit header over the raster element's own buffer (no copy), so
      // fillContours writes labels straight into the element.
      cv::Mat labels(height, width, CV_16UC1, mpLabels->getRawData());
      mProgress.report("Filling blobs", 50, NORMAL);
      unsigned short lastLabel = 0;
      fillContours(labels, contours, hierarchy, lastLabel, 0, 0);

      // create a pseudocolor layer for display
      mProgress.report("Displaying results", 90, NORMAL);
      mpLabels->updateData();
      if (!createPseudocolor(lastLabel))
      {
         mProgress.report("Unable to create blob layer", 0, ERRORS, true);
         return false;
      }

      // add blob count to the metadata
      DynamicObject* pMeta = pLabels->getMetadata();
      VERIFY(pMeta);
      // lastLabel, as left by fillContours, is used as the blob count.
      unsigned int numBlobs = static_cast<unsigned int>(lastLabel);
      pMeta->setAttribute("BlobCount", numBlobs);
      if (numBlobs == 0 && !isBatch())
      {
         // Inform the user that there were no blobs so they don't think there was an
         // error running the algorithm. No need to do this in batch since this is
         // represented in the metadata already.
         mProgress.report("No blobs were found.", 95, WARNING);
      }
      // update the output arg list
      // NOTE(review): &numBlobs is the address of a local; this assumes
      // setPlugInArgValue copies the value rather than keeping the pointer.
      if (pOutArgList != NULL)
      {
         pOutArgList->setPlugInArgValue("Blobs", pLabels.get());
         pOutArgList->setPlugInArgValue("Number of Blobs", &numBlobs);
      }
   }
   catch(const cv::Exception& exc)
   {
      mProgress.report(exc.what(), 0, ERRORS, true);
      return false;
   }

   // Success: hand ownership of the element over so it survives this call.
   pLabels.release();
   mProgress.report("Labeling connected components", 100, NORMAL);
   mProgress.upALevel();
   return true;
}