void test_dbn(){
    DBN dbn;

    dbn.display();

    std::vector<etl::dyn_vector<double>> images;
    std::vector<uint8_t> labels;

    etl::dyn_vector<double> result(100);

    dbn.pretrain(images, 10);
    dbn.svm_train(images, labels);
    dbn.svm_train(images.begin(), images.end(), labels.begin(), labels.end());
    dbn.svm_grid_search(images, labels);
    dbn.svm_grid_search(images.begin(), images.end(), labels.begin(), labels.end());
    dbn.svm_predict(images[1]);
}
// Example #2
// 0
// Runs the requested `actions` on `dbn`, in order, using the inputs described
// by `task`.
//
// Supported actions:
//  - "pretrain": read task.pretraining samples (plus task.pretraining_clean
//                samples when task.pt_desc.denoising is set) and pretrain the
//                network for task.pt_desc.epochs epochs.
//  - "train":    read task.training samples/labels, fine-tune for
//                task.ft_desc.epochs epochs and print the training error.
//  - "test":     read task.testing samples/labels, predict every sample and
//                print the global error rate, per-class accuracy/error rate and
//                a confusion matrix in percent (row = true label,
//                column = predicted label).
//  - "save" / "load": store / load the weights to / from task.w_desc.file.
// Any other action prints an error and is skipped.
//
// When an input is missing, unreadable or inconsistent, or the network cannot
// perform the requested action, an error is printed and the function returns
// without running the remaining actions.
//
// NOTE(review): DBN, Container and Three are not declared in this view; they
// are presumably template parameters of this function declared elsewhere.
void execute(DBN& dbn, task& task, const std::vector<std::string>& actions) {
    print_title("Network");
    dbn.display();

    using dbn_t = std::decay_t<DBN>;

    //Execute all the actions sequentially
    for (auto& action : actions) {
        if (action == "pretrain") {
            print_title("Pretraining");

            if (task.pretraining.samples.empty()) {
                std::cout << "dllp: error: pretrain is not possible without a pretraining input" << std::endl;
                return;
            }

            std::vector<Container> pt_samples;

            //Try to read the samples
            if (!read_samples<Three>(task.pretraining.samples, pt_samples)) {
                std::cout << "dllp: error: failed to read the pretraining samples" << std::endl;
                return;
            }

            if (task.pt_desc.denoising) {
                // The static_if below compiles to nothing for networks that are
                // not denoising-capable; report that instead of silently
                // skipping the pretraining.
                if (!dbn_t::layers_t::is_denoising) {
                    std::cout << "dllp: error: denoising pretraining is not possible with this network" << std::endl;
                    return;
                }

                std::vector<Container> clean_samples;

                //Try to read the samples
                if (!read_samples<Three>(task.pretraining_clean.samples, clean_samples)) {
                    std::cout << "dllp: error: failed to read the clean samples" << std::endl;
                    return;
                }

                // Both ranges are handed to pretrain_denoising; presumably they
                // are consumed pairwise, so mismatched sizes cannot be valid.
                if (clean_samples.size() != pt_samples.size()) {
                    std::cout << "dllp: error: the number of clean samples does not match the number of pretraining samples" << std::endl;
                    return;
                }

                //Pretrain the network
                cpp::static_if<dbn_t::layers_t::is_denoising>([&](auto f) {
                    f(dbn).pretrain_denoising(pt_samples.begin(), pt_samples.end(), clean_samples.begin(), clean_samples.end(), task.pt_desc.epochs);
                });
            } else {
                //Pretrain the network
                dbn.pretrain(pt_samples.begin(), pt_samples.end(), task.pt_desc.epochs);
            }
        } else if (action == "train") {
            print_title("Training");

            if (task.training.samples.empty() || task.training.labels.empty()) {
                std::cout << "dllp: error: train is not possible without samples and labels" << std::endl;
                return;
            }

            using last_layer = typename dbn_t::template layer_type<dbn_t::layers - 1>;

            // As for denoising: without this check the static_if below would
            // silently skip training for networks that cannot be fine-tuned.
            if (!sgd_possible<last_layer>::value) {
                std::cout << "dllp: error: train is not possible with this network" << std::endl;
                return;
            }

            std::vector<Container> ft_samples;
            std::vector<std::size_t> ft_labels;

            //Try to read the samples
            if (!read_samples<Three>(task.training.samples, ft_samples)) {
                std::cout << "dllp: error: failed to read the training samples" << std::endl;
                return;
            }

            //Try to read the labels
            if (!read_labels(task.training.labels, ft_labels)) {
                std::cout << "dllp: error: failed to read the training labels" << std::endl;
                return;
            }

            if (ft_samples.size() != ft_labels.size()) {
                std::cout << "dllp: error: the number of training samples and training labels differ" << std::endl;
                return;
            }

            //Train the network
            cpp::static_if<sgd_possible<last_layer>::value>([&](auto f) {
                auto ft_error = f(dbn).fine_tune(ft_samples, ft_labels, task.ft_desc.epochs);
                std::cout << "Train Classification Error:" << ft_error << std::endl;
            });

        } else if (action == "test") {
            print_title("Testing");

            if (task.testing.samples.empty() || task.testing.labels.empty()) {
                std::cout << "dllp: error: test is not possible without samples and labels" << std::endl;
                return;
            }

            std::vector<Container> test_samples;
            std::vector<std::size_t> test_labels;

            //Try to read the samples
            if (!read_samples<Three>(task.testing.samples, test_samples)) {
                std::cout << "dllp: error: failed to read the test samples" << std::endl;
                return;
            }

            //Try to read the labels
            if (!read_labels(task.testing.labels, test_labels)) {
                std::cout << "dllp: error: failed to read the test labels" << std::endl;
                return;
            }

            // test_labels[i] is read for every sample below.
            if (test_samples.size() != test_labels.size()) {
                std::cout << "dllp: error: the number of test samples and test labels differ" << std::endl;
                return;
            }

            // Otherwise every rate below would be 0/0.
            if (test_samples.empty()) {
                std::cout << "dllp: error: no test samples were read" << std::endl;
                return;
            }

            auto classes = dbn_t::output_size();

            // conf(label, predicted) counts the test samples of class `label`
            // that the network classified as `predicted`.
            etl::dyn_matrix<std::size_t, 2> conf(classes, classes, 0.0);

            std::size_t n  = test_samples.size();
            std::size_t tp = 0;

            for (std::size_t i = 0; i < n; ++i) {
                // Bind by reference: copying each sample is wasted work.
                auto& sample = test_samples[i];
                auto label   = test_labels[i];

                auto predicted = static_cast<std::size_t>(dbn.predict(sample));

                // Labels come from an input file; an out-of-range class would
                // index outside the confusion matrix.
                if (label >= classes || predicted >= classes) {
                    std::cout << "dllp: error: invalid class for test sample " << i
                              << " (label " << label << ", predicted " << predicted
                              << ", classes " << classes << ")" << std::endl;
                    return;
                }

                if (predicted == label) {
                    ++tp;
                }

                ++conf(label, predicted);
            }

            double test_error = (n - tp) / double(n);

            std::cout << "Error rate: " << test_error << std::endl;
            std::cout << "Accuracy: " << (1.0 - test_error) << std::endl
                      << std::endl;

            std::cout << "Results per class" << std::endl;

            // Sum of the per-class error rates, over the classes that have at
            // least one test sample (a class without samples has no rate).
            double overall          = 0.0;
            std::size_t represented = 0;

            std::cout << "   | Accuracy | Error rate |" << std::endl;

            for (std::size_t l = 0; l < classes; ++l) {
                std::size_t total = etl::sum(conf(l));
                std::cout << std::setw(3) << l;

                if (total == 0) {
                    std::cout << "|" << std::setw(10) << "n/a" << "|" << std::setw(12) << "n/a" << "|" << std::endl;
                    continue;
                }

                double error = (total - conf(l, l)) / double(total);
                std::cout << "|" << std::setw(10) << (1.0 - error) << "|" << std::setw(12) << error << "|" << std::endl;
                overall += error;
                ++represented;
            }

            std::cout << std::endl;

            // represented >= 1: n > 0 and every sample was counted in a valid row.
            double overall_error = overall / represented;

            std::cout << "Overall Error rate: " << overall_error << std::endl;
            std::cout << "Overall Accuracy: " << 1.0 - overall_error << std::endl
                      << std::endl;

            std::cout << "Confusion Matrix (%)" << std::endl
                      << std::endl;

            std::cout << "    ";
            for (std::size_t l = 0; l < classes; ++l) {
                std::cout << std::setw(5) << l << " ";
            }
            std::cout << std::endl;

            // Fixed notation with one decimal keeps "100.0" inside the 5-char
            // column (default notation with a small precision prints 100 as
            // "1e+02"). The stream state is restored so later output is unaffected.
            const auto old_flags     = std::cout.flags();
            const auto old_precision = std::cout.precision();
            std::cout << std::fixed << std::setprecision(1);

            for (std::size_t l = 0; l < classes; ++l) {
                std::size_t total = etl::sum(conf(l));
                std::cout << std::setw(3) << l << "|";
                for (std::size_t p = 0; p < classes; ++p) {
                    const double pct = total == 0 ? 0.0 : 100.0 * (conf(l, p) / double(total));
                    std::cout << std::setw(5) << pct << "|";
                }
                std::cout << std::endl;
            }

            std::cout.flags(old_flags);
            std::cout.precision(old_precision);
            std::cout << std::endl;
        } else if (action == "save") {
            print_title("Save Weights");

            dbn.store(task.w_desc.file);
            std::cout << "Weights saved" << std::endl;
        } else if (action == "load") {
            print_title("Load Weights");

            dbn.load(task.w_desc.file);
            std::cout << "Weights loaded" << std::endl;
        } else {
            std::cout << "dllp: error: Invalid action: " << action << std::endl;
        }
    }

    //TODO
}