// Computes a Lidfaces model from the images in src and their corresponding
// labels.
//
// Steps performed below:
//   1. detectKeypointsAndDescriptors() yields the keypoints of every image and
//      one stacked descriptor matrix for all images (in image order).
//   2. The descriptors are clustered with k-means; the cluster centres are
//      stored in mCenters and act as the vocabulary used by predict().
//   3. Each image's cluster assignments are turned into a histogram
//      (generateHistograms), normalised (normalizeHistograms) and stored in
//      mCodebook, one histogram per image.
//   4. labels is stored in mLabels. predict() reads it with at<int>(i) for the
//      i-th codebook entry, so it presumably must hold one CV_32S label per
//      image in src — TODO confirm against callers.
void Lidfaces::train(cv::InputArrayOfArrays src, cv::InputArray labels)
{
    std::vector<std::vector<cv::KeyPoint> > allKeyPoints;
    cv::Mat descriptors;

    // Get SIFT keypoints and LID descriptors
    detectKeypointsAndDescriptors(src, allKeyPoints, descriptors);

    // kmeans() cannot cluster an empty point set; report that clearly here
    // rather than from inside kmeans().
    CV_Assert(!descriptors.empty());

    // kmeans function requires points to be CV_32F
    descriptors.convertTo(descriptors, CV_32FC1);

    // The number of clusters is a fraction of the total number of descriptors
    // (the parameter is multiplied directly, i.e. used as a fraction).
    // kmeans() requires 1 <= K <= number of points, so clamp the truncated
    // product into that range: a small training set would otherwise give
    // K == 0, and a fraction > 1 would give K > rows.
    const int requestedClusters = static_cast<int>(
        params::lidFace::clustersAsPercentageOfKeypoints * descriptors.rows);
    const int CLUSTER_COUNT = requestedClusters < 1
        ? 1
        : (requestedClusters > descriptors.rows ? descriptors.rows : requestedClusters);

    // The nth element of histogramLabels is the index of the cluster that the
    // nth descriptor (and hence the nth keypoint overall) belongs to.
    cv::Mat histogramLabels;
    kmeans(
        descriptors, // The points we are clustering are the descriptors
        CLUSTER_COUNT, // The number of clusters (K)
        histogramLabels, // The label of the corresponding keypoint
        params::kmeans::termCriteria,
        params::kmeans::attempts,
        params::kmeans::flags,
        mCenters);

    // Convert to single channel 32 bit float as the matrix needs to be in a form supported
    // by calcHist
    histogramLabels.convertTo(histogramLabels, CV_32FC1);

    // We end up with a histogram for each image
    const size_t NUM_IMAGES = getSize(src);
    std::vector<cv::Mat> hists(NUM_IMAGES);

    // histogramLabels holds the labels of ALL descriptors from EVERY image,
    // stacked in image order. Split it into one contiguous row range per image,
    // sized by that image's own keypoint count (counts may differ per image).
    std::vector<cv::Mat> separatedLabels;
    separatedLabels.reserve(NUM_IMAGES);
    int startRow = 0;
    for (size_t i = 0; i < NUM_IMAGES; ++i)
    {
        const int endRow = startRow + static_cast<int>(allKeyPoints[i].size());
        separatedLabels.push_back(histogramLabels.rowRange(startRow, endRow));
        startRow = endRow;
    }
    // Every descriptor is expected to belong to exactly one image.
    assert(startRow == histogramLabels.rows);

    // Populate the hists vector
    generateHistograms(hists, separatedLabels, CLUSTER_COUNT);

    // Make the magnitude of each histogram equal to 1
    normalizeHistograms(hists);

    mCodebook = hists;
    mLabels = labels.getMat();
}
Ejemplo n.º 2
0
// Loads a random-forest ensemble from one JSON file per tree.
//
// Trees are read in parallel with tbb::parallel_for. Each worker writes only to
// its own slot of `ensemble` / `configurations`. Afterwards every tree must
// share tree 0's training configuration (compared with strict == false) and
// its number of classes. Tree 0's configuration is then adopted, with the
// device ids and acceleration mode taken from the arguments, and
// normalizeHistograms(histogramBias) is applied.
//
// Throws std::runtime_error if treeFiles is empty, if a tree file yields no
// tree, or if the trees' configurations / class counts differ. Exceptions
// raised inside the parallel loop are propagated by TBB; depending on the TBB
// build they may arrive wrapped (e.g. tbb::captured_exception) — TODO confirm.
RandomForestImage::RandomForestImage(const std::vector<std::string>& treeFiles,
                    const std::vector<int>& deviceIds,
                    const AccelerationMode accelerationMode,
                    const double histogramBias)
 : configuration(), ensemble(treeFiles.size()),
   m_predictionAllocator(boost::make_shared<cuv::pooled_cuda_allocator>())
{

    if (treeFiles.empty()) {
        throw std::runtime_error("cannot construct empty forest");
    }

    std::vector<TrainingConfiguration> configurations(treeFiles.size());

    tbb::parallel_for(tbb::blocked_range<size_t>(0, treeFiles.size(), 1),
            [&](const tbb::blocked_range<size_t>& range) {

                for (size_t tree = range.begin(); tree != range.end(); ++tree) {
                    CURFIL_INFO("reading tree " << tree << " from " << treeFiles[tree]);

                    boost::shared_ptr<RandomTreeImage> randomTree;

                    std::string hostname;
                    boost::filesystem::path folderTraining;
                    boost::posix_time::ptime date;

                    // Named so it does not shadow the member `configuration`.
                    TrainingConfiguration treeConfiguration = RandomTreeImport::readJSON(treeFiles[tree],
                            randomTree, hostname, folderTraining, date);

                    CURFIL_INFO("trained " << date << " on " << hostname);
                    CURFIL_INFO("training folder: " << folderTraining);

                    // An assert() disappears in NDEBUG builds and would leave a
                    // null tree in the ensemble, dereferenced further down.
                    if (!randomTree) {
                        throw std::runtime_error("failed to read tree from " + treeFiles[tree]);
                    }

                    ensemble[tree] = randomTree;
                    configurations[tree] = treeConfiguration;

                    CURFIL_INFO(*randomTree);
                }

            });

    // All trees must agree with tree 0.
    const bool strict = false;
    const auto numClassesOfFirstTree = ensemble[0]->getTree()->getNumClasses();

    for (size_t i = 1; i < treeFiles.size(); ++i) {
        if (!configurations[0].equals(configurations[i], strict)) {
            CURFIL_ERROR("configuration of tree 0: " << configurations[0]);
            CURFIL_ERROR("configuration of tree " << i << ": " << configurations[i]);
            throw std::runtime_error("different configurations");
        }

        const auto numClasses = ensemble[i]->getTree()->getNumClasses();
        if (numClassesOfFirstTree != numClasses) {
            CURFIL_ERROR("number of classes of tree 0: " << numClassesOfFirstTree);
            CURFIL_ERROR("number of classes of tree " << i << ": " << numClasses);
            throw std::runtime_error("different number of classes in trees");
        }
    }

    CURFIL_INFO("training configuration " << configurations[0]);

    this->configuration = configurations[0];
    this->configuration.setDeviceIds(deviceIds);
    this->configuration.setAccelerationMode(accelerationMode);

    normalizeHistograms(histogramBias);
}
// Predicts the label of a single sample together with a distance score.
//
// Each descriptor of the sample is assigned to its nearest training cluster
// centre (mCenters, Euclidean distance). The assignments are turned into a
// normalised histogram, which is compared with every training histogram in
// mCodebook using the chi-square distance. Distances are averaged per
// training label (mLabels), and the label with the smallest average wins.
//
// label: the predicted label, or -1 if there are no training histograms.
// dist:  the smallest average chi-square distance, or DBL_MAX if none.
void Lidfaces::predict(cv::InputArray src, int& label, double& dist) const
{
    label = -1;
    dist = DBL_MAX;

    std::vector<std::vector<cv::KeyPoint> > keyPoints;
    cv::Mat descriptors;

    // detectKeypointsAndDescriptors() works on a collection of images, so wrap
    // the single sample in a one-element vector.
    std::vector<cv::Mat> imageVector(1, src.getMat());

    // Get SIFT keypoints and LID descriptors
    detectKeypointsAndDescriptors(imageVector, keyPoints, descriptors);

    // Use a local CV_32F view/copy of the centres so this const method never
    // modifies the trained model.
    cv::Mat centers;
    if (mCenters.type() == CV_32FC1)
        centers = mCenters;
    else
        mCenters.convertTo(centers, CV_32FC1);
    descriptors.convertTo(descriptors, CV_32FC1);

    // Assign every descriptor to its nearest centroid.
    cv::Mat histogramLabels(descriptors.rows, 1, CV_32F);
    for (int descriptorIndex = 0; descriptorIndex < descriptors.rows; ++descriptorIndex)
    {
        const cv::Mat descriptor = descriptors.row(descriptorIndex);
        double smallestDist = DBL_MAX;
        int closestCentroidIndex = 0;

        for (int centroidIndex = 0; centroidIndex < centers.rows; ++centroidIndex)
        {
            // L2 distance without materialising the difference vector.
            const double currentDist = cv::norm(descriptor, centers.row(centroidIndex), cv::NORM_L2);
            if (currentDist < smallestDist)
            {
                smallestDist = currentDist;
                closestCentroidIndex = centroidIndex;
            }
        }
        histogramLabels.at<float>(descriptorIndex) = static_cast<float>(closestCentroidIndex);
    }

    std::vector<cv::Mat> separatedLabels(1, histogramLabels);
    std::vector<cv::Mat> hists(1);

    generateHistograms(hists, separatedLabels, centers.rows);
    normalizeHistograms(hists);

    // Chi-square distance from this sample's histogram to every training
    // histogram, keyed by that histogram's training label.
    std::multimap<int, double> distances;
    for (size_t codebookIndex = 0; codebookIndex < mCodebook.size(); ++codebookIndex)
    {
        const double currentDist = cv::compareHist(hists[0], mCodebook[codebookIndex], CV_COMP_CHISQR);
        distances.insert(std::pair<int, double>(mLabels.at<int>(static_cast<int>(codebookIndex)), currentDist));
    }

    // Average the distances per label and keep the smallest average. Every
    // label present in the training data is considered, whatever its value.
    // The multimap keeps equal keys adjacent and in ascending order, so ties
    // resolve to the smaller label.
    double smallestAverageDist = DBL_MAX;
    int closestLabel = -1;
    std::multimap<int, double>::const_iterator it = distances.begin();
    while (it != distances.end())
    {
        const int curLabel = it->first;
        double totalDist = 0;
        int count = 0;
        for (; it != distances.end() && it->first == curLabel; ++it)
        {
            totalDist += it->second;
            ++count;
        }

        const double averageDist = totalDist / count;
        if (averageDist < smallestAverageDist)
        {
            smallestAverageDist = averageDist;
            closestLabel = curLabel;
        }
    }

    label = closestLabel;
    dist = smallestAverageDist;
}