Example #1
0
/// Handles a message arriving on input port 1.
/// When the handler is compiled in (f_kMeans__1_kMeans not defined) and the
/// message type is "kMeans", the payload in _msg.ss is decoded with a cereal
/// binary archive into a tuple and passed to _1_kMeans, then the function
/// returns. Any other message falls through to kMeans_UP_kMeans.
/// NOTE(review): kMeans_DOWN__1_kMeans_types / kMeans_DOWN__1_kMeans_tuple_get
/// look like generator-provided macros that append extra tuple element types
/// and matching std::get<> arguments (possibly empty) — confirm with the
/// code generator. The expansion of kMeans_UP_kMeans is not visible here.
void input_1(Message&& _msg) {
#ifndef f_kMeans__1_kMeans
	if (_msg.getType() == "kMeans") {
		// Decode the serialized handler arguments from the message stream.
		cereal::BinaryInputArchive iarchive(_msg.ss);
		std::tuple<std::vector<std::vector<double> >, std::vector<std::vector<double> >, int, double kMeans_DOWN__1_kMeans_types > _data;
		iarchive(std::get<0>(_data), std::get<1>(_data), std::get<2>(_data), std::get<3>(_data) kMeans_DOWN__1_kMeans_tuple_get );
		// Dispatch to the handler with the decoded arguments.
		_1_kMeans(std::get<0>(_data), std::get<1>(_data), std::get<2>(_data), std::get<3>(_data) kMeans_DOWN__1_kMeans_tuple_get );
		return;
	}
#endif
	// Unhandled (or compiled-out) message types: fallback provided by macro.
	kMeans_UP_kMeans
}
void test_shared()
{
    boost::shared_ptr<A> ip(new C);
    boost::shared_ptr<A> op1;
    boost::shared_ptr<A> op2;
    {
        std::vector<char> buffer;
        hpx::serialization::output_archive oarchive(buffer);
        oarchive << ip << ip;

        hpx::serialization::input_archive iarchive(buffer);
        iarchive >> op1;
        iarchive >> op2;
    }

    HPX_TEST_NEQ(op1.get(), ip.get());
    HPX_TEST_NEQ(op2.get(), ip.get());
    HPX_TEST_EQ(op1.get(), op2.get());
    HPX_TEST_EQ(op1->foo(), std::string("C::foo"));
    HPX_TEST_EQ(op2->foo(), std::string("C::foo"));
    HPX_TEST_EQ(static_cast<C*>(op1.get())->a, 1);
    HPX_TEST_EQ(static_cast<C*>(op1.get())->b, 2);
    HPX_TEST_EQ(static_cast<C*>(op1.get())->get_c(), 3);
    HPX_TEST_EQ(static_cast<C*>(op2.get())->a, 1);
    HPX_TEST_EQ(static_cast<C*>(op2.get())->b, 2);
    HPX_TEST_EQ(static_cast<C*>(op2.get())->get_c(), 3);
    HPX_TEST_EQ(op1.use_count(), 2);
}
void test(T minval, T maxval)
{
    // Round-trip a registered partitioned_vector (every element set to 42)
    // through an archive with data chunking disabled, and compare element-wise
    // against a zero-filled destination after deserialization.
    {
        std::vector<char> archive_bytes;

        hpx::serialization::output_archive writer(archive_bytes,
            hpx::serialization::disable_data_chunking);

        auto const count = static_cast<std::size_t>(maxval - minval);

        hpx::partitioned_vector<T> source(count);
        source.register_as("test_vector");
        hpx::parallel::fill(hpx::parallel::execution::par,
            std::begin(source), std::end(source), 42);

        writer << source;

        hpx::serialization::input_archive reader(archive_bytes);

        hpx::partitioned_vector<T> target(source.size());
        hpx::parallel::fill(hpx::parallel::execution::par,
            std::begin(target), std::end(target), 0);

        reader >> target;

        HPX_TEST_EQ(source.size(), target.size());
        for (std::size_t idx = 0, n = source.size(); idx != n; ++idx)
        {
            HPX_TEST_EQ(source[idx], target[idx]);
        }
    }
}
/// Serializes an A and the same raw B* (pointing at a D) twice, reads them
/// back, and checks that raw_ptr round-trips are untracked: each read
/// allocates a separate object. All heap objects are released before exit.
int main()
{
    std::vector<char> buffer;
    hpx::serialization::output_archive oarchive(buffer);
    oarchive << A();

    B * const b1 = new D;
    oarchive << hpx::serialization::detail::raw_ptr(b1);
    oarchive << hpx::serialization::detail::raw_ptr(b1);

    hpx::serialization::input_archive iarchive(buffer);
    A a;
    iarchive >> a;
    B *b2 = nullptr, *b3 = nullptr;
    iarchive >> hpx::serialization::detail::raw_ptr(b2);
    iarchive >> hpx::serialization::detail::raw_ptr(b3);

    HPX_TEST_EQ(a.a, 8);
    HPX_TEST_NEQ(b2, b1);
    HPX_TEST_NEQ(b2, b3); //untracked
    HPX_TEST_EQ(b2->b, b1->b);

    delete b2;

    HPX_TEST_EQ(b1->b, 4711);

    // b3 is a separate allocation (untracked raw_ptr) and b1 was created
    // above; neither was freed before. The guard avoids a double delete if
    // the untracked-pointer check above failed and b3 aliases b2 or b1.
    // NOTE(review): deleting through B* assumes B has a virtual destructor,
    // as the original `delete b2` already did.
    if (b3 != b2 && b3 != b1)
    {
        delete b3;
    }
    delete b1;

    return hpx::util::report_errors();
}
Example #5
0
/// Writes a vector of doubles to "file_serialization_test.archive" through an
/// HPX output archive, reads it back, and reports any difference on stderr.
/// @return 0 if the round-tripped vector matches exactly, 1 otherwise.
int main(int argc, char* argv[])
{
    std::size_t size = 0;
    std::vector<double> os;
    {
        file_wrapper buffer("file_serialization_test.archive",
            std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
        hpx::serialization::output_archive oarchive(buffer);
        for(double c = -100.0; c < +100.0; c += 1.3)
        {
            os.push_back(c);
        }
        oarchive << os;
        // Remember how many bytes were written so the reader knows the extent.
        size = oarchive.bytes_written();
    }

    std::size_t errors = 0;
    {
        file_wrapper buffer("file_serialization_test.archive",
            std::ios_base::in | std::ios_base::binary);
        hpx::serialization::input_archive iarchive(buffer, size);
        std::vector<double> is;
        iarchive >> is;

        if (os.size() != is.size())
        {
            std::cerr << "Size mismatch: " << os.size() << " != "
                      << is.size() << "\n";
            ++errors;
        }

        // Only compare indices valid in both vectors to avoid reading past
        // the end of a shorter result.
        std::size_t const common =
            os.size() < is.size() ? os.size() : is.size();
        for(std::size_t i = 0; i < common; ++i)
        {
            if (os[i] != is[i])
            {
                std::cerr << "Mismatch for element " << i << ":"
                          << os[i] << " != " << is[i] << "\n";
                ++errors;
            }
        }
    }
    // Previously this always returned 0, so a corrupted round-trip could
    // never fail the test.
    return errors == 0 ? 0 : 1;
}
void test_intrusive()
{
    boost::intrusive_ptr<D> ip(new F);
    boost::intrusive_ptr<D> op1;
    boost::intrusive_ptr<D> op2;
    {
        std::vector<char> buffer;
        hpx::serialization::output_archive oarchive(buffer);
        oarchive << ip << ip;

        hpx::serialization::input_archive iarchive(buffer);
        iarchive >> op1;
        iarchive >> op2;
    }
    HPX_TEST_NEQ(op1.get(), ip.get());
    HPX_TEST_NEQ(op2.get(), ip.get());
    HPX_TEST_EQ(op1.get(), op2.get());
    HPX_TEST_EQ(op1->foo(), std::string("F::foo"));
    HPX_TEST_EQ(op2->foo(), std::string("F::foo"));
    HPX_TEST_EQ(static_cast<F*>(op1.get())->a, 1);
    HPX_TEST_EQ(static_cast<F*>(op1.get())->b, 2);
    HPX_TEST_EQ(static_cast<F*>(op1.get())->get_c(), 3);
    HPX_TEST_EQ(static_cast<F*>(op2.get())->a, 1);
    HPX_TEST_EQ(static_cast<F*>(op2.get())->b, 2);
    HPX_TEST_EQ(static_cast<F*>(op2.get())->get_c(), 3);

    HPX_TEST_EQ(ip->count, 1);
    HPX_TEST_EQ(op1->count, 2);
    HPX_TEST_EQ(op2->count, 2);
    op1.reset();
    HPX_TEST_EQ(op2->count, 1);
}
Example #7
0
/// Round-trips std::list<bool> and std::list<A<bool>> through HPX archives
/// and compares the results element by element.
void test_bool()
{
    {
        std::vector<char> buffer;
        hpx::serialization::output_archive oarchive(buffer);

        std::list<bool> os;
        os.push_back(true);
        os.push_back(false);
        os.push_back(false);
        os.push_back(true);
        oarchive << os;

        hpx::serialization::input_archive iarchive(buffer);
        std::list<bool> is;
        iarchive >> is;
        HPX_TEST_EQ(os.size(), is.size());
        // Advance both iterators together. The old loop never incremented
        // them, so only the first element was ever compared. Stopping at the
        // shorter list also avoids walking past the end of `is`.
        auto ot = os.begin();
        auto it = is.begin();
        for(; ot != os.end() && it != is.end(); ++ot, ++it)
        {
            HPX_TEST_EQ(*ot, *it);
        }
    }
    {
        std::vector<char> buffer;
        hpx::serialization::output_archive oarchive(buffer);

        std::list<A<bool> > os;
        os.push_back(true);
        os.push_back(false);
        os.push_back(false);
        os.push_back(true);
        oarchive << os;

        hpx::serialization::input_archive iarchive(buffer);
        std::list<A<bool> > is;
        iarchive >> is;
        HPX_TEST_EQ(os.size(), is.size());
        auto ot = os.begin();
        auto it = is.begin();
        for(; ot != os.end() && it != is.end(); ++ot, ++it)
        {
            HPX_TEST_EQ(ot->t_, it->t_);
        }
    }
}
	void mo_cmaes_optimizer::deserialize()
	{
		ifstream ifs("save/mocma.sav");
		boost::archive::text_iarchive iarchive(ifs);
		iarchive >> mocma >> max_step_count >> current_step_count;
		ifs.close();
		boost::filesystem::remove_all("save/");
		logger::get_logger().log_info("Settings were successfully loaded!");
	}
 /**
  * Read the index version info from a JSON file.
  *
  * @param versionFile path of the version file; its "indexVersion" and
  *        "kmerLength" entries are stored into indexVersion_ and kmerLength_.
  * @return true on success (failures are reported by throwing).
  * @throws std::invalid_argument if the file does not exist or cannot be
  *         opened for reading. cereal may additionally throw on malformed JSON.
  */
 bool load(boost::filesystem::path& versionFile) {
     namespace bfs = boost::filesystem;
     if(!bfs::exists(versionFile)) {
         fmt::MemoryWriter infostr;
         infostr << "Error: The index version file " << versionFile.string()
                 << " doesn't seem to exist.  Please try re-building the sailfish "
                 "index.";
         throw std::invalid_argument(infostr.str());
     }
     std::ifstream ifs(versionFile.string());
     // exists() does not imply the path is a readable file (permissions,
     // directories, ...); report that explicitly instead of letting the JSON
     // archive fail on an unopened stream.
     if (!ifs) {
         fmt::MemoryWriter infostr;
         infostr << "Error: The index version file " << versionFile.string()
                 << " could not be opened for reading.";
         throw std::invalid_argument(infostr.str());
     }
     {
         // Scoped so the archive is destroyed before the stream goes away.
         cereal::JSONInputArchive iarchive(ifs);
         iarchive(cereal::make_nvp("indexVersion", indexVersion_),
                  cereal::make_nvp("kmerLength", kmerLength_));
     }
     // ifs is closed automatically when it goes out of scope.
     return true;
 }
Example #10
0
/// Round-trips std::list<T> and std::list<A<T>> filled with [min, max)
/// through HPX archives and compares the results element by element.
void test(T min, T max)
{
    {
        std::vector<char> buffer;
        hpx::serialization::output_archive oarchive(buffer);
        std::list<T> os;
        for(T c = min; c < max; ++c)
        {
            os.push_back(c);
        }
        oarchive << os;
        hpx::serialization::input_archive iarchive(buffer);
        std::list<T> is;
        iarchive >> is;
        HPX_TEST_EQ(os.size(), is.size());
        // Advance both iterators in lockstep. The old loop never moved them,
        // so it re-checked the first element os.size() times.
        auto ot = os.begin();
        auto it = is.begin();
        for(; ot != os.end() && it != is.end(); ++ot, ++it)
        {
            HPX_TEST_EQ(*ot, *it);
        }
    }
    {
        std::vector<char> buffer;
        hpx::serialization::output_archive oarchive(buffer);
        std::list<A<T> > os;
        for(T c = min; c < max; ++c)
        {
            os.push_back(c);
        }
        oarchive << os;
        hpx::serialization::input_archive iarchive(buffer);
        std::list<A<T> > is;
        iarchive >> is;
        HPX_TEST_EQ(os.size(), is.size());
        auto ot = os.begin();
        auto it = is.begin();
        for(; ot != os.end() && it != is.end(); ++ot, ++it)
        {
            HPX_TEST_EQ(ot->t_, it->t_);
        }
    }
}
TEST(Serialization, OtherBase) {
    // Serialize a Derived held through shared_ptr<OtherBase>, then check that
    // deserialization yields a Derived with all three members restored.
    std::stringstream stream;

    {
        cereal::PortableBinaryOutputArchive out(stream);
        std::shared_ptr<OtherBase> original = std::make_shared<Derived>(3, 4, 5);
        out(original);
    }

    {
        cereal::PortableBinaryInputArchive in(stream);
        std::shared_ptr<OtherBase> restored;
        in(restored);

        Derived* const derived = dynamic_cast<Derived*>(restored.get());
        ASSERT_NE(nullptr, derived);
        ASSERT_EQ(3, derived->derivedVal);
        ASSERT_EQ(4, derived->baseVal);
        ASSERT_EQ(5, derived->otherBaseVal);
    }
}
void test_template()
{
    std::vector<char> buffer;
    {
        std::shared_ptr<C<float> > struct_a(new C<float>(777));
        hpx::serialization::output_archive oarchive(buffer);
        oarchive << struct_a;
    }
    {
        std::shared_ptr<C<float> > struct_b;
        hpx::serialization::input_archive iarchive(buffer);
        iarchive >> struct_b;
        HPX_TEST_EQ(struct_b->c, 777);
    }
}
void test_delegate()
{
    std::vector<char> buffer;
    {
        std::shared_ptr<A> struct_a(new A(4711));
        hpx::serialization::output_archive oarchive(buffer);
        oarchive << struct_a;
    }
    {
        std::shared_ptr<A> struct_b;
        hpx::serialization::input_archive iarchive(buffer);
        iarchive >> struct_b;
        HPX_TEST_EQ(struct_b->a, 4711);
    }
}
void test_member()
{
    std::vector<char> buffer;
    {
        boost::shared_ptr<A> struct_a(new E<float>(1, 2.3f));
        hpx::serialization::output_archive oarchive(buffer);
        oarchive << struct_a;
    }
    {
        boost::shared_ptr<A> struct_b;
        hpx::serialization::input_archive iarchive(buffer);
        iarchive >> struct_b;
        HPX_TEST_EQ(struct_b->a, 1);
        HPX_TEST_EQ(dynamic_cast<E<float>*>(&*struct_b)->c.c, 2.3f);
    }
}
void test_custom_factory()
{
    std::vector<char> buffer;

    {
        std::shared_ptr<B> struct_a(new B(1981, false));
        hpx::serialization::output_archive oarchive(buffer);
        oarchive << struct_a;
    }
    {
        std::shared_ptr<B> struct_b;
        hpx::serialization::input_archive iarchive(buffer);
        iarchive >> struct_b;
        HPX_TEST_EQ(struct_b->b, 1981);
    }
}
Example #16
0
ReturnCode TimerTable::Load()
{
    // Restore the timer table and category managers from the configured save
    // file. Returns kReturnCodeFileOpenFailed if the file cannot be opened,
    // kReturnCodeFileReadFailed on a version mismatch or a failed read, and
    // kReturnCodeOK otherwise.
    std::ifstream input(impl_->config_.GetSaveFileName(), std::ios::binary);
    if (!input) {
        return kReturnCodeFileOpenFailed;
    }

    Serialization::IArchive archive(input);
    if (archive.GetVersion() != SerialVersion()) {
        return kReturnCodeFileReadFailed;
    }

    archive >> impl_->timer_table_manager_;
    archive >> impl_->category_manager_;

    if (!input.fail()) {
        return kReturnCodeOK;
    }
    return kReturnCodeFileReadFailed;
}
/// Round-trips an unordered_map<T, A<T>> with keys min, min+0.5, ... (< max)
/// and checks that every original entry is present with an equal value.
void test_fp(T min, T max)
{
    {
        std::vector<char> buffer;
        hpx::serialization::output_archive oarchive(buffer);
        std::unordered_map<T, A<T> > os;
        for(T c = min; c < max; c += static_cast<T>(0.5))
        {
            os.insert(std::make_pair(c, A<T>(c)));
        }
        oarchive << os;
        hpx::serialization::input_archive iarchive(buffer);
        std::unordered_map<T, A<T> > is;
        iarchive >> is;
        HPX_TEST_EQ(os.size(), is.size());
        for (const auto& v: os)
        {
            // find() instead of operator[]: operator[] would insert a
            // default-constructed value for a key missing from `is`, hiding
            // the failure and mutating the map under test.
            auto const found = is.find(v.first);
            HPX_TEST_EQ(found != is.end(), true);
            if (found != is.end())
            {
                HPX_TEST_EQ(v.second, found->second);
            }
        }
    }
}
void test_basic()
{
    std::vector<char> buffer;
    hpx::serialization::output_archive oarchive(buffer);
    oarchive << A();
    D d;
    B const & b1 = d;
    oarchive << b1;

    hpx::serialization::input_archive iarchive(buffer);
    A a;
    iarchive >> a;
    D d1;
    B & b2 = d1;
    iarchive >> b2;
    HPX_TEST_EQ(a.a, 8);
    HPX_TEST_EQ(&b2, &d1);
    HPX_TEST_EQ(b2.b, d1.b);
    HPX_TEST_EQ(d.b, d1.b);
    HPX_TEST_EQ(d.d, d1.d);
}
/// Serializes a map from size_t to vector<int> (10 entries, each an iota
/// sequence starting at its key) and checks every entry after reading back.
/// NOTE(review): the source is a std::map but the destination is a
/// std::unordered_map. This only works if both share a wire format; confirm
/// whether the mismatch is intentional.
void test_vector_as_value()
{
    std::vector<char> buffer;
    hpx::serialization::output_archive oarchive(buffer);
    std::map<size_t, std::vector<int> > os;
    for (int k = 0; k < 10; ++k)
    {
        std::vector<int> vec(10);
        std::iota(vec.begin(), vec.end(), k);
        os.insert(std::make_pair(k, vec));
    }
    oarchive << os;
    hpx::serialization::input_archive iarchive(buffer);
    std::unordered_map<size_t, std::vector<int> > is;
    iarchive >> is;
    HPX_TEST_EQ(os.size(), is.size());
    for (const auto& v: os)
    {
        // find() instead of operator[]: operator[] would insert an empty
        // vector for a missing key and mask the failure.
        auto const found = is.find(v.first);
        HPX_TEST_EQ(found != is.end(), true);
        if (found != is.end())
        {
            HPX_TEST_EQ(v.second, found->second);
        }
    }
}
int main()
{
    std::vector<char> buffer;
    hpx::serialization::output_archive oarchive(buffer);
    oarchive << A();
    D d;
    B const & b1 = d;
    oarchive << b1;

    hpx::serialization::input_archive iarchive(buffer);
    A a;
    iarchive >> a;
    D d1;
    B & b2 = d1;
    iarchive >> b2;
    HPX_TEST_EQ(a.a, 8);
    HPX_TEST_EQ(&b2, &d1);
    HPX_TEST_EQ(b2.b, d1.b);
    HPX_TEST_EQ(d.b, d1.b);
    HPX_TEST_EQ(d.d, d1.d);

    return hpx::util::report_errors();
}
Example #21
0
/// Thread body of the results router.
/// Connects to the analytic output queue with ZMQ_PULL. On success it loops
/// forever: it reads one serialized Image (boost text archive) per message,
/// writes it as "<timestamp>_<frame>.jpg" under
/// /usr/local/opencctv/images/<_iAnalyticInstId>/, and records the result
/// through AnalyticResultGateway::insertResults. If the connection fails, an
/// error is written to cerr and the function returns.
/// NOTE(review): the loop has no exit condition, and nothing here catches
/// exceptions. Anything thrown by read(), deserialization, imwrite or
/// insertResults would escape the thread entry point.
void ResultsRouterThread::operator ()()
{
	//Connect to the analytic output queue
	bool bConnectResult = mqPtr->connectTo(_analyticOutQueueAddress,ZMQ_PULL);

	if(bConnectResult){

		//cout << "ResultsRouterThread::operator: "<<"Results router thread for analytic instance " << _iAnalyticInstId << " started." << endl;
		//cout << "Results router thread for analytic instance " << _iAnalyticInstId << " started." << endl;

		cout << "Results router thread : Start reading the images from the analytic process  : " << _iAnalyticInstId << endl << endl;

		AnalyticResultGateway _analyticResultGateway;
		// Reused across iterations; each deserialization overwrites it.
		Image imageResultObj;

		// Per-instance output directory: /usr/local/opencctv/images/<id>/
		ostringstream filePath;
		filePath << "/usr/local/opencctv/images/";
		filePath << _iAnalyticInstId;
		filePath << "/";

		//TODO : Add the error handling to mkdir
		// The return value is ignored. mkdir creates only the last path
		// component, so it fails (silently here) if the parent is missing;
		// it also fails harmlessly if the directory already exists.
		mkdir(filePath.str().c_str(), 0775);

		// Sequence number appended to each file name.
		int iFrameCount = 1;

		//Read the output results and store in the OpenCCTV database
		while(1)		{

			//Read serialized results image object
			string serializedImageStr = mqPtr->read();

			//Initialize with received serialized string data
			std::istringstream ibuffer(serializedImageStr);

			//De-serialize and create image object
			boost::archive::text_iarchive iarchive(ibuffer);
			iarchive & imageResultObj;

			//TODO Define the image store location in a config file
			//1. Write the images to the path /usr/local/opencctv/images
			Mat matResultObj = JpegImage::toOpenCvMat(imageResultObj);

			// File name: <dir>/<timestamp>_<frame>.jpg
			ostringstream filename;
			filename << filePath.str();
			filename << imageResultObj.getTimestamp() << "_";
			filename << iFrameCount;
			filename << ".jpg";

			// NOTE(review): imwrite's return value is not checked.
			imwrite(filename.str(), matResultObj);

			++iFrameCount;

			//2. Store the image detais in the results DB
			_analyticResultGateway.insertResults(_iAnalyticInstId,imageResultObj.getTimestamp(),imageResultObj.getResult(),filename.str());

		}

	}else{
		cerr << "ResultsRouterThread:operator: Error connecting to the analytic output queue......." << endl;
	}
}
/// Entry point of the depth random-forest trainer.
///
/// Reads (data, label) image filename pairs from a two-column CSV list. It
/// trains a forest with the configured weak learner and writes the forest to
/// the JSON or the binary output file (TCLAP's xorAdd requires exactly one of
/// -j / -b). Optionally it prints sample counts, match counts and confusion
/// matrices.
///
/// @return 0 on success; -1 if a std::runtime_error is caught. A malformed
///         image list also terminates with exit(-1).
int main(int argc, const char* argv[]) {
    try {
        // Parse command line arguments.
        TCLAP::CmdLine cmd("Depth RF trainer", ' ', "0.3");
        TCLAP::ValueArg<std::string> image_list_file_arg("f", "image-list-file", "File containing the names of image files", true, "", "string", cmd);
        TCLAP::ValueArg<int> num_of_classes_arg("n", "num-of-classes", "Number of classes in the data", true, 1, "int", cmd);
        TCLAP::SwitchArg print_confusion_matrix_switch("m", "conf-matrix", "Print confusion matrix", cmd, true);
        TCLAP::ValueArg<int> background_label_arg("l", "background-label", "Lower bound of background labels to be ignored", false, -1, "int", cmd);
        TCLAP::ValueArg<std::string> json_forest_file_arg("j", "json-forest-file", "JSON file where the trained forest should be saved", false, "forest.json", "string");
        TCLAP::ValueArg<std::string> binary_forest_file_arg("b", "binary-forest-file", "Binary file where the trained forest should be saved", false, "forest.bin", "string");
        TCLAP::ValueArg<std::string> config_file_arg("c", "config", "YAML file with training parameters", false, "", "string", cmd);
#if AIT_MULTI_THREADING
        TCLAP::ValueArg<int> num_of_threads_arg("t", "threads", "Number of threads to use", false, -1, "int", cmd);
#endif
        cmd.xorAdd(json_forest_file_arg, binary_forest_file_arg);
        cmd.parse(argc, argv);

        const int num_of_classes = num_of_classes_arg.getValue();
        bool print_confusion_matrix = print_confusion_matrix_switch.getValue();
        const std::string image_list_file = image_list_file_arg.getValue();

        // Initialize training and weak-learner parameters to defaults or load from file
        ForestTrainerT::ParametersT training_parameters;
        WeakLearnerT::ParametersT weak_learner_parameters;
        if (config_file_arg.isSet()) {
            ait::log_info(false) << "Reading config file " << config_file_arg.getValue() << "... " << std::flush;
            std::ifstream ifile_config(config_file_arg.getValue());
            // Report an unopenable file directly rather than as a JSON parse error.
            if (!ifile_config.good()) {
                throw std::runtime_error("Unable to open config file");
            }
            cereal::JSONInputArchive iarchive(ifile_config);
            iarchive(cereal::make_nvp("training_parameters", training_parameters));
            iarchive(cereal::make_nvp("weak_learner_parameters", weak_learner_parameters));
            ait::log_info(false) << " Done." << std::endl;
        }
#if AIT_MULTI_THREADING
        if (num_of_threads_arg.isSet()) {
            training_parameters.num_of_threads = num_of_threads_arg.getValue();
        }
#endif

        // Read image file list
        ait::log_info(false) << "Reading image list ... " << std::flush;
        std::vector<std::tuple<std::string, std::string>> image_list;
        std::ifstream ifile(image_list_file);
        if (!ifile.good()) {
            throw std::runtime_error("Unable to open image list file");
        }
        ait::CSVReader<std::string> csv_reader(ifile);
        for (auto it = csv_reader.begin(); it != csv_reader.end(); ++it) {
            if (it->size() != 2) {
                cmd.getOutput()->usage(cmd);
                ait::log_error() << "Image list file should contain two columns with the data and label filenames.";
                exit(-1);
            }
            const std::string& data_filename = (*it)[0];
            const std::string& label_filename = (*it)[1];

            // Relative paths are resolved against the directory of the list file.
            boost::filesystem::path data_path = boost::filesystem::path(data_filename);
            boost::filesystem::path label_path = boost::filesystem::path(label_filename);
            if (!data_path.is_absolute()) {
                data_path = boost::filesystem::path(image_list_file).parent_path();
                data_path /= data_filename;
            }
            if (!label_path.is_absolute()) {
                label_path = boost::filesystem::path(image_list_file).parent_path();
                label_path /= label_filename;
            }

            image_list.push_back(std::make_tuple(data_path.string(), label_path.string()));
        }
        ait::log_info(false) << " Done." << std::endl;

        // TODO: Ensure that label images do not contain values > num_of_classes except for background pixels. Other approach: Test samples directly below.

        // Set lower bound for background pixel lables
        ait::label_type background_label;
        if (background_label_arg.isSet()) {
            background_label = background_label_arg.getValue();
        } else {
            background_label = num_of_classes;
        }
        weak_learner_parameters.background_label = background_label;

        // Create weak learner and trainer.
        StatisticsT::Factory statistics_factory(num_of_classes);
        WeakLearnerT iwl(weak_learner_parameters, statistics_factory);
        ForestTrainerT trainer(iwl, training_parameters);
        SampleProviderT sample_provider(image_list, weak_learner_parameters);
        BaggingWrapperT bagging_wrapper(trainer, sample_provider);

#ifdef AIT_TESTING
        RandomEngineT rnd_engine(11);
#else
        std::random_device rnd_device;
        // Draw the seed once. Previously rnd_device() was called separately
        // for logging and for seeding, so the logged value was not the seed.
        const auto rnd_seed = rnd_device();
        ait::log_info() << "rnd(): " << rnd_seed;
        RandomEngineT rnd_engine(rnd_seed);
#endif

        // Train a forest and time it.
        auto start_time = std::chrono::high_resolution_clock::now();
        // TODO
        //		ForestTrainerT::ForestT forest = bagging_wrapper.train_forest(rnd_engine);
        // TODO: Testing all samples for comparison with depth_trainer
        sample_provider.clear_samples();
        for (int i = 0; i < image_list.size(); ++i) {
            sample_provider.load_samples_from_image(i, rnd_engine);
        }
        SampleIteratorT samples_start = sample_provider.get_samples_begin();
        SampleIteratorT samples_end = sample_provider.get_samples_end();
        ait::log_info() << "Starting training ...";
        ForestTrainerT::ForestT forest = trainer.train_forest(samples_start, samples_end, rnd_engine);
        auto stop_time = std::chrono::high_resolution_clock::now();
        auto duration = stop_time - start_time;
        auto period = std::chrono::high_resolution_clock::period();
        double elapsed_seconds = duration.count() * period.num / static_cast<double>(period.den);
        ait::log_info() << "Done.";
        ait::log_info() << "Running time: " << elapsed_seconds;

        // Serialize the forest to either the JSON or the binary file. In both
        // branches the archive is scoped, and " Done." is logged only after
        // the archive has been destroyed, because cereal completes its output
        // when the archive goes out of scope.
        if (json_forest_file_arg.isSet()) {
            ait::log_info(false) << "Writing json forest file " << json_forest_file_arg.getValue() << "... " << std::flush;
            {
                std::ofstream ofile(json_forest_file_arg.getValue());
                if (!ofile.good()) {
                    throw std::runtime_error("Unable to open json forest file for writing");
                }
                cereal::JSONOutputArchive oarchive(ofile);
                oarchive(cereal::make_nvp("forest", forest));
            }
            ait::log_info(false) << " Done." << std::endl;
        } else if (binary_forest_file_arg.isSet()) {
            ait::log_info(false) << "Writing binary forest file " << binary_forest_file_arg.getValue() << "... " << std::flush;
            {
                std::ofstream ofile(binary_forest_file_arg.getValue(), std::ios_base::binary);
                if (!ofile.good()) {
                    throw std::runtime_error("Unable to open binary forest file for writing");
                }
                cereal::BinaryOutputArchive oarchive(ofile);
                oarchive(cereal::make_nvp("forest", forest));
            }
            ait::log_info(false) << " Done." << std::endl;
        } else {
            // Throw a std::runtime_error (not a bare string literal) so the
            // handler below catches it instead of std::terminate being called.
            throw std::runtime_error("This should never happen. Either a JSON or a binary forest file have to be specified!");
        }

        // Optionally: Compute some stats and print them.
        if (print_confusion_matrix) {
            ait::log_info(false) << "Creating samples for testing ... " << std::flush;
            sample_provider.clear_samples();
            for (int i = 0; i < image_list.size(); ++i) {
                sample_provider.load_samples_from_image(i, rnd_engine);
            }
            SampleIteratorT samples_start = sample_provider.get_samples_begin();
            SampleIteratorT samples_end = sample_provider.get_samples_end();
            ait::log_info(false) << " Done." << std::endl;

            std::vector<ait::size_type> sample_counts(num_of_classes, 0);
            for (auto sample_it = samples_start; sample_it != samples_end; sample_it++) {
                ++sample_counts[sample_it->get_label()];
            }
            auto logger = ait::log_info(true);
            logger << "Sample counts>> ";
            for (int c = 0; c < num_of_classes; ++c) {
                if (c > 0) {
                    logger << ", ";
                }
                logger << "class " << c << ": " << sample_counts[c];
            }
            logger.close();
            // For each tree extract leaf node indices for each sample.
            std::vector<std::vector<ait::size_type>> forest_leaf_indices = forest.evaluate(samples_start, samples_end);

            // Compute number of prediction matches based on a majority vote among the forest.
            int match = 0;
            int no_match = 0;
            for (auto tree_it = forest.cbegin(); tree_it != forest.cend(); ++tree_it) {
                for (auto sample_it = samples_start; sample_it != samples_end; sample_it++) {
                    const auto &node_it = tree_it->cbegin() + (forest_leaf_indices[tree_it - forest.cbegin()][sample_it - samples_start]);
                    const auto &statistics = node_it->get_statistics();
                    auto max_it = std::max_element(statistics.get_histogram().cbegin(), statistics.get_histogram().cend());
                    auto label = max_it - statistics.get_histogram().cbegin();
                    if (label == sample_it->get_label()) {
                        match++;
                    } else {
                        no_match++;
                    }
                }
            }
            ait::log_info() << "Match: " << match << ", no match: " << no_match;

            // Compute confusion matrix.
            auto forest_utils = ait::make_forest_utils(forest);
            auto confusion_matrix = forest_utils.compute_confusion_matrix(samples_start, samples_end);
            ait::log_info() << "Confusion matrix:" << std::endl << confusion_matrix;
            auto norm_confusion_matrix = ait::EvaluationUtils::normalize_confusion_matrix(confusion_matrix);
            ait::log_info() << "Normalized confusion matrix:" << std::endl << norm_confusion_matrix;
            ait::log_info() << "Diagonal of normalized confusion matrix:" << std::endl << norm_confusion_matrix.diagonal();

            // Computing per-frame confusion matrix
            ait::log_info() << "Computing per-frame confusion matrix.";
            using ConfusionMatrixType = typename decltype(forest_utils)::MatrixType;
            ConfusionMatrixType per_frame_confusion_matrix(num_of_classes, num_of_classes);
            per_frame_confusion_matrix.setZero();
            WeakLearnerT::ParametersT full_parameters(weak_learner_parameters);
            // Modify parameters to retrieve all pixels per sample
            full_parameters.samples_per_image_fraction = 1.0;
            SampleProviderT full_sample_provider(image_list, full_parameters);
            for (int i = 0; i < image_list.size(); ++i) {
                full_sample_provider.clear_samples();
                full_sample_provider.load_samples_from_image(i, rnd_engine);
                samples_start = full_sample_provider.get_samples_begin();
                samples_end = full_sample_provider.get_samples_end();
                forest_utils.update_confusion_matrix(per_frame_confusion_matrix, samples_start, samples_end);
            }
            ait::log_info() << "Per-frame confusion matrix:" << std::endl << per_frame_confusion_matrix;
            ConfusionMatrixType per_frame_norm_confusion_matrix = ait::EvaluationUtils::normalize_confusion_matrix(per_frame_confusion_matrix);
            ait::log_info() << "Normalized per-frame confusion matrix:" << std::endl << per_frame_norm_confusion_matrix;
            ait::log_info() << "Diagonal of normalized per-frame confusion matrix:" << std::endl << per_frame_norm_confusion_matrix.diagonal();
            ait::log_info() << "Mean of diagonal of normalized per-frame confusion matrix:" << std::endl << per_frame_norm_confusion_matrix.diagonal().mean();
        }

    } catch (const std::runtime_error& error) {
        std::cerr << "Runtime exception occured" << std::endl;
        std::cerr << error.what() << std::endl;
        // Signal failure to the shell instead of falling through to return 0.
        return -1;
    }

    return 0;
}
Example #23
0
/// Loads this table from the cereal portable-binary file at `path`, where it
/// is stored under the name "myData".
/// @throws cereal::Exception if `path` is null or the file cannot be opened.
///         cereal may also throw while reading malformed or truncated data.
void LALRTable::Load(const char* path) {
    // Constructing an ifstream from a null C string is undefined behavior.
    if (path == nullptr) {
        throw cereal::Exception("LALRTable::Load: null path");
    }
    std::ifstream is(path, std::ios::in | std::ios::binary);
    if (!is) {
        // Otherwise cereal would report only a generic read failure that
        // does not name the file.
        throw cereal::Exception(std::string("LALRTable::Load: failed to open ") + path);
    }
    cereal::PortableBinaryInputArchive iarchive(is);
    iarchive( cereal::make_nvp("myData", *this) );
}