示例#1
0
// Usage identical to RandomTreeImage class
void RandomForestImage::train(const std::vector<LabeledRGBDImage>& trainLabelImages, size_t numLabels,
        bool trainTreesSequentially) {

    if (trainLabelImages.empty()) {
        throw std::runtime_error("no training images");
    }

    const size_t treeCount = ensemble.size();

    const int numThreads = configuration.getNumThreads();
    tbb::task_scheduler_init init(numThreads);

    CURFIL_INFO("learning image tree ensemble. " << treeCount << " trees with " << numThreads << " threads");

    for (size_t treeNr = 0; treeNr < treeCount; ++treeNr) {
        ensemble[treeNr] = boost::make_shared<RandomTreeImage>(treeNr, configuration);
    }

    RandomSource randomSource(configuration.getRandomSeed());
    const int SEED = randomSource.uniformSampler(0xFFFF).getNext();

    auto train =
            [&](boost::shared_ptr<RandomTreeImage>& tree) {
                utils::Timer timer;
                auto seed = SEED + tree->getId();
                RandomSource randomSource(seed);

                std::vector<LabeledRGBDImage> sampledTrainLabelImages = trainLabelImages;

                if (configuration.getMaxImages() > 0 && static_cast<int>(trainLabelImages.size()) > configuration.getMaxImages()) {
                    ReservoirSampler<LabeledRGBDImage> reservoirSampler(configuration.getMaxImages());
                    Sampler sampler = randomSource.uniformSampler(0, 10 * trainLabelImages.size());
                    for (auto& image : trainLabelImages) {
                        reservoirSampler.sample(sampler, image);
                    }

                    CURFIL_INFO("tree " << tree->getId() << ": sampled " << reservoirSampler.getReservoir().size()
                            << " out of " << trainLabelImages.size() << " images");
                    sampledTrainLabelImages = reservoirSampler.getReservoir();
                }

                tree->train(sampledTrainLabelImages, randomSource, configuration.getSamplesPerImage() / treeCount, numLabels);
                CURFIL_INFO("finished tree " << tree->getId() << " with random seed " << seed << " in " << timer.format(3));
            };

    if (!trainTreesSequentially && numThreads > 1) {
        tbb::parallel_for_each(ensemble.begin(), ensemble.end(), train);
    } else {
        std::for_each(ensemble.begin(), ensemble.end(), train);
    }
}
示例#2
0
void RandomForestImage::normalizeHistograms(const double histogramBias) {

    treeData.clear();

    for (size_t treeNr = 0; treeNr < ensemble.size(); treeNr++) {
        CURFIL_INFO("normalizing histograms of tree " << treeNr <<
                " with " << ensemble[treeNr]->getTree()->countLeafNodes() << " leaf nodes");
        ensemble[treeNr]->normalizeHistograms(histogramBias);
        treeData.push_back(convertTree(ensemble[treeNr]));
    }
}
示例#3
0
// Builds a forest from serialized trees, one tree per file.
//
// The files are read in parallel. Afterwards the constructor checks that all trees were trained
// with equal (non-strict) configurations and the same number of classes. It then adopts the
// configuration of tree 0, overrides its device ids and acceleration mode with the given values,
// and normalizes the leaf histograms with histogramBias.
//
// Throws std::runtime_error if treeFiles is empty, if a file yields no tree, or if the trees
// disagree in configuration or number of classes. An exception raised while reading a file is
// thrown inside tbb::parallel_for; TBB propagates it to this thread (depending on the TBB build,
// possibly wrapped in a tbb::captured_exception).
RandomForestImage::RandomForestImage(const std::vector<std::string>& treeFiles,
                    const std::vector<int>& deviceIds,
                    const AccelerationMode accelerationMode,
                    const double histogramBias)
 : configuration(), ensemble(treeFiles.size()),
   m_predictionAllocator(boost::make_shared<cuv::pooled_cuda_allocator>())
{

    if (treeFiles.empty()) {
        throw std::runtime_error("cannot construct empty forest");
    }

    std::vector<TrainingConfiguration> configurations(treeFiles.size());

    // Each iteration writes only to its own index of `ensemble` and `configurations`.
    tbb::parallel_for(tbb::blocked_range<size_t>(0, treeFiles.size(), 1),
            [&](const tbb::blocked_range<size_t>& range) {

                for(size_t tree = range.begin(); tree != range.end(); tree++) {
                    CURFIL_INFO("reading tree " << tree << " from " << treeFiles[tree]);

                    boost::shared_ptr<RandomTreeImage> randomTree;

                    std::string hostname;
                    boost::filesystem::path folderTraining;
                    boost::posix_time::ptime date;

                    // Named differently from the member `configuration` to avoid shadowing it.
                    TrainingConfiguration treeConfiguration = RandomTreeImport::readJSON(treeFiles[tree], randomTree, hostname,
                            folderTraining, date);

                    CURFIL_INFO("trained " << date << " on " << hostname);
                    CURFIL_INFO("training folder: " << folderTraining);

                    // This was an assert(), which is compiled out with NDEBUG. A null tree would then
                    // be stored and dereferenced by the log statement below.
                    if (!randomTree) {
                        throw std::runtime_error("failed to read tree from " + treeFiles[tree]);
                    }

                    ensemble[tree] = randomTree;
                    configurations[tree] = treeConfiguration;

                    CURFIL_INFO(*randomTree);
                }

            });

    // All trees must be compatible with tree 0.
    for (size_t i = 1; i < treeFiles.size(); i++) {
        bool strict = false;
        if (!configurations[0].equals(configurations[i], strict)) {
            CURFIL_ERROR("configuration of tree 0: " << configurations[0]);
            CURFIL_ERROR("configuration of tree " << i << ": " << configurations[i]);
            throw std::runtime_error("different configurations");
        }

        if (ensemble[0]->getTree()->getNumClasses() != ensemble[i]->getTree()->getNumClasses()) {
            CURFIL_ERROR("number of classes of tree 0: " << ensemble[0]->getTree()->getNumClasses());
            CURFIL_ERROR("number of classes of tree " << i << ": " << ensemble[i]->getTree()->getNumClasses());
            throw std::runtime_error("different number of classes in trees");
        }
    }

    CURFIL_INFO("training configuration " << configurations[0]);

    this->configuration = configurations[0];
    this->configuration.setDeviceIds(deviceIds);
    this->configuration.setAccelerationMode(accelerationMode);

    normalizeHistograms(histogramBias);
}
示例#4
0
// Evaluates randomForest on every image found in folderTesting.
//
// For each image this checks that the ground-truth labels are known to the forest, predicts a
// label image, and optionally writes the input, ground truth, prediction and per-label
// probability images. It also logs the per-image pixel accuracy. At the end it logs the
// normalized confusion matrix and the average pixel accuracy over all images.
//
// folderTesting:          folder that listImageFilenames() is applied to.
// folderPrediction:       output folder for written images; if empty, no images are written.
// useDepthFilling:        forwarded to loadImagePair().
// writeProbabilityImages: additionally write one probability image per non-ignored label
//                         (only if folderPrediction is non-empty).
//
// Throws std::runtime_error if no files are found, or if a ground-truth image contains a label
// >= randomForest.getNumClasses().
void test(RandomForestImage& randomForest, const std::string& folderTesting,
        const std::string& folderPrediction, const bool useDepthFilling,
        const bool writeProbabilityImages) {

    auto filenames = listImageFilenames(folderTesting);
    if (filenames.empty()) {
        throw std::runtime_error(std::string("found no files in ") + folderTesting);
    }

    CURFIL_INFO("got " << filenames.size() << " files for prediction");

    CURFIL_INFO("label/color map:");
    const auto labelColorMap = randomForest.getLabelColorMap();
    for (const auto& labelColor : labelColorMap) {
        const auto color = LabelImage::decodeLabel(labelColor.first);
        CURFIL_INFO("label: " << static_cast<int>(labelColor.first) << ", color: RGB(" << color << ")");
    }

    // totalMutex guards i, the two averages and totalConfusionMatrix inside the parallel loop.
    tbb::mutex totalMutex;
    utils::Average averageAccuracy;
    utils::Average averageAccuracyWithoutVoid;

    const LabelType numClasses = randomForest.getNumClasses();
    ConfusionMatrix totalConfusionMatrix(numClasses);

    // Number of images processed so far (used only for progress output).
    size_t i = 0;

    const bool useCIELab = randomForest.getConfiguration().isUseCIELab();
    CURFIL_INFO("CIELab: " << useCIELab);
    CURFIL_INFO("DepthFilling: " << useDepthFilling);

    const bool onGPU = randomForest.getConfiguration().getAccelerationMode() == GPU_ONLY;

    // With grain size 1 TBB may process images concurrently. A grain size equal to the number of
    // files makes the range indivisible, so images are processed one after another. Presumably
    // this is done because CPU prediction is parallel internally (NOTE(review): confirm).
    size_t grainSize = 1;
    if (!onGPU) {
        grainSize = filenames.size();
    }

    bool writeImages = true;
    if (folderPrediction.empty()) {
        CURFIL_WARNING("no prediction folder given. will not write images");
        writeImages = false;
    }

    tbb::parallel_for(tbb::blocked_range<size_t>(0, filenames.size(), grainSize),
            [&](const tbb::blocked_range<size_t>& range) {
                for(size_t fileNr = range.begin(); fileNr != range.end(); fileNr++) {
                    const std::string& filename = filenames[fileNr];
                    const auto imageLabelPair = loadImagePair(filename, useCIELab, useDepthFilling);
                    const RGBDImage& testImage = imageLabelPair.getRGBDImage();
                    const LabelImage& groundTruth = imageLabelPair.getLabelImage();

                    // Reject ground truth that contains labels the forest was not trained for.
                    for(int y = 0; y < groundTruth.getHeight(); y++) {
                        for(int x = 0; x < groundTruth.getWidth(); x++) {
                            const LabelType label = groundTruth.getLabel(x, y);
                            if (label >= numClasses) {
                                const auto msg = (boost::format("illegal label in ground truth image '%s' at pixel (%d,%d): %d RGB(%3d,%3d,%3d) (numClasses: %d)")
                                        % filename
                                        % x % y
                                        % static_cast<int>(label)
                                        % LabelImage::decodeLabel(label)[0]
                                        % LabelImage::decodeLabel(label)[1]
                                        % LabelImage::decodeLabel(label)[2]
                                        % static_cast<int>(numClasses)
                                ).str();
                                throw std::runtime_error(msg);
                            }
                        }
                    }

                    boost::filesystem::path fn(testImage.getFilename());
                    const std::string basepath = folderPrediction + "/" + boost::filesystem::basename(fn);

                    // Filled by predict(); accessed below as probabilities(label, y, x).
                    cuv::ndarray<float, cuv::host_memory_space> probabilities;

                    LabelImage prediction = randomForest.predict(testImage, &probabilities, onGPU);

#ifndef NDEBUG
                    for(LabelType label = 0; label < randomForest.getNumClasses(); label++) {
                        if (!randomForest.shouldIgnoreLabel(label)) {
                            continue;
                        }

                        // Ignored classes must not be predicted because they were not sampled.
                        // Dimension 0 is the label, so rows are shape(1) and columns are shape(2).
                        // The previous code iterated shape(0)/shape(1) and checked the wrong region.
                        for(size_t y = 0; y < probabilities.shape(1); y++) {
                            for(size_t x = 0; x < probabilities.shape(2); x++) {
                                const float& probability = probabilities(label, y, x);
                                assert(probability == 0.0);
                            }
                        }
                    }
#endif

                    if (writeImages && writeProbabilityImages) {
                        utils::Profile profile("writeProbabilityImages");
                        RGBDImage probabilityImage(testImage.getWidth(), testImage.getHeight());
                        for(LabelType label = 0; label < randomForest.getNumClasses(); label++) {

                            if (randomForest.shouldIgnoreLabel(label)) {
                                continue;
                            }

                            // Grey-scale image: the label's probability is written to all three channels.
                            for(int y = 0; y < probabilityImage.getHeight(); y++) {
                                for(int x = 0; x < probabilityImage.getWidth(); x++) {
                                    const float& probability = probabilities(label, y, x);
                                    for(int c = 0; c < 3; c++) {
                                        probabilityImage.setColor(x, y, c, probability);
                                    }
                                }
                            }
                            const std::string probabilityFilename =
                                    (boost::format("%s_label_%d.png") % basepath % static_cast<int>(label)).str();
                            probabilityImage.saveColor(probabilityFilename);
                        }
                    }

                    size_t thisNumber;

                    {
                        tbb::mutex::scoped_lock total(totalMutex);
                        thisNumber = i++;
                    }

                    if (writeImages) {
                        utils::Profile profile("writeImages");
                        testImage.saveColor(basepath + ".png");
                        testImage.saveDepth(basepath + "_depth.png");
                        groundTruth.save(basepath + "_ground_truth.png");
                        prediction.save(basepath + "_prediction.png");
                    }

                    ConfusionMatrix confusionMatrix(numClasses);
                    double accuracy = calculatePixelAccuracy(prediction, groundTruth, true, &confusionMatrix);
                    double accuracyWithoutVoid = calculatePixelAccuracy(prediction, groundTruth, false);

                    tbb::mutex::scoped_lock lock(totalMutex);

                    CURFIL_INFO("prediction " << (thisNumber + 1) << "/" << filenames.size()
                            << " (" << testImage.getFilename() << "): pixel accuracy (without void): " << 100 * accuracy
                            << " (" << 100 * accuracyWithoutVoid << ")");

                    averageAccuracy.addValue(accuracy);
                    averageAccuracyWithoutVoid.addValue(accuracyWithoutVoid);

                    totalConfusionMatrix += confusionMatrix;
                }

            });

    // parallel_for has returned, so every worker has finished. The accumulators can be read
    // without taking totalMutex.
    double accuracy = averageAccuracy.getAverage();
    double accuracyWithoutVoid = averageAccuracyWithoutVoid.getAverage();

    totalConfusionMatrix.normalize();

    CURFIL_INFO(totalConfusionMatrix);

    CURFIL_INFO("pixel accuracy: " << 100 * accuracy);
    CURFIL_INFO("pixel accuracy without void: " << 100 * accuracyWithoutVoid);
}