void test_dbn() { float learning_rate = 0.1, lrcoef = 0.95; float n_epochs =1000; vector<int> hl; int nrows = 6; int ncols = 6; int outs = 2; matrix2dPtr train_X = matrix2dPtr ( new matrix2d(nrows, ncols) ); matrix2dPtr train_Y = matrix2dPtr( new matrix2d(nrows, outs) ); (*train_X) <<= 1,1,1, 0, 0, 0,1, 0, 1, 0, 0, 0,1, 1, 1, 0, 0, 0,0, 0, 1, 1, 1, 0,0, 0, 1, 1, 0, 0,0, 0, 1, 1, 1, 0; (*train_Y ) <<=1,0,1,0,1,0,0,1,0,1,0,1; hl.push_back(55); hl.push_back(55); matrix2dPtr test_X = matrix2dPtr ( new matrix2d(2, ncols) ); (*test_X) <<= 1,1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0; DBN dbn = DBN(train_X, train_Y, hl); dbn.pretrain(learning_rate, lrcoef, 15, n_epochs); int f_epochs = 200; float finetuneLR = 0.1 , finetuneCoef = 0.95; dbn.finetuning( finetuneLR, finetuneCoef, f_epochs ); matrix2dPtr output ;//= matrix2dPtr (new matrix2d (nrows, outs)); output = dbn.predict( test_X ); printMatrix("final output", output); }
void test_dbn(){ DBN dbn; dbn.display(); std::vector<etl::dyn_vector<double>> images; std::vector<uint8_t> labels; etl::dyn_vector<double> result(100); dbn.pretrain(images, 10); dbn.svm_train(images, labels); dbn.svm_train(images.begin(), images.end(), labels.begin(), labels.end()); dbn.svm_grid_search(images, labels); dbn.svm_grid_search(images.begin(), images.end(), labels.begin(), labels.end()); dbn.svm_predict(images[1]); }
void execute(DBN& dbn, task& task, const std::vector<std::string>& actions) { print_title("Network"); dbn.display(); using dbn_t = std::decay_t<DBN>; //Execute all the actions sequentially for (auto& action : actions) { if (action == "pretrain") { print_title("Pretraining"); if (task.pretraining.samples.empty()) { std::cout << "dllp: error: pretrain is not possible without a pretraining input" << std::endl; return; } std::vector<Container> pt_samples; //Try to read the samples if (!read_samples<Three>(task.pretraining.samples, pt_samples)) { std::cout << "dllp: error: failed to read the pretraining samples" << std::endl; return; } if (task.pt_desc.denoising) { std::vector<Container> clean_samples; //Try to read the samples if (!read_samples<Three>(task.pretraining_clean.samples, clean_samples)) { std::cout << "dllp: error: failed to read the clean samples" << std::endl; return; } //Pretrain the network cpp::static_if<dbn_t::layers_t::is_denoising>([&](auto f) { f(dbn).pretrain_denoising(pt_samples.begin(), pt_samples.end(), clean_samples.begin(), clean_samples.end(), task.pt_desc.epochs); }); } else { //Pretrain the network dbn.pretrain(pt_samples.begin(), pt_samples.end(), task.pt_desc.epochs); } } else if (action == "train") { print_title("Training"); if (task.training.samples.empty() || task.training.labels.empty()) { std::cout << "dllp: error: train is not possible without samples and labels" << std::endl; return; } std::vector<Container> ft_samples; std::vector<std::size_t> ft_labels; //Try to read the samples if (!read_samples<Three>(task.training.samples, ft_samples)) { std::cout << "dllp: error: failed to read the training samples" << std::endl; return; } //Try to read the labels if (!read_labels(task.training.labels, ft_labels)) { std::cout << "dllp: error: failed to read the training labels" << std::endl; return; } using last_layer = typename dbn_t::template layer_type<dbn_t::layers - 1>; //Train the network cpp::static_if<sgd_possible<last_layer>::value>([&](auto f) { auto ft_error = f(dbn).fine_tune(ft_samples, ft_labels, task.ft_desc.epochs); std::cout << "Train Classification Error:" << ft_error << std::endl; }); } else if (action == "test") { print_title("Testing"); if (task.testing.samples.empty() || task.testing.labels.empty()) { std::cout << "dllp: error: test is not possible without samples and labels" << std::endl; return; } std::vector<Container> test_samples; std::vector<std::size_t> test_labels; //Try to read the samples if (!read_samples<Three>(task.testing.samples, test_samples)) { std::cout << "dllp: error: failed to read the test samples" << std::endl; return; } //Try to read the labels if (!read_labels(task.testing.labels, test_labels)) { std::cout << "dllp: error: failed to read the test labels" << std::endl; return; } auto classes = dbn_t::output_size(); etl::dyn_matrix<std::size_t, 2> conf(classes, classes, 0.0); std::size_t n = test_samples.size(); std::size_t tp = 0; for (std::size_t i = 0; i < test_samples.size(); ++i) { auto sample = test_samples[i]; auto label = test_labels[i]; auto predicted = dbn.predict(sample); if (predicted == label) { ++tp; } ++conf(label, predicted); } double test_error = (n - tp) / double(n); std::cout << "Error rate: " << test_error << std::endl; std::cout << "Accuracy: " << (1.0 - test_error) << std::endl << std::endl; std::cout << "Results per class" << std::endl; double overall = 0.0; std::cout << " | Accuracy | Error rate |" << std::endl; for (std::size_t l = 0; l < classes; ++l) { std::size_t total = etl::sum(conf(l)); double acc = (total - conf(l, l)) / double(total); std::cout << std::setw(3) << l; std::cout << "|" << std::setw(10) << (1.0 - acc) << "|" << std::setw(12) << acc << "|" << std::endl; overall += acc; } std::cout << std::endl; std::cout << "Overall Error rate: " << overall / classes << std::endl; std::cout << "Overall Accuracy: " << 1.0 - (overall / classes) << std::endl << std::endl; std::cout << "Confusion Matrix (%)" << std::endl << std::endl; std::cout << " "; for (std::size_t l = 0; l < classes; ++l) { std::cout << std::setw(5) << l << " "; } std::cout << std::endl; for (std::size_t l = 0; l < classes; ++l) { std::size_t total = etl::sum(conf(l)); std::cout << std::setw(3) << l << "|"; for (std::size_t p = 0; p < classes; ++p) { std::cout << std::setw(5) << std::setprecision(2) << 100.0 * (conf(l, p) / double(total)) << "|"; } std::cout << std::endl; } std::cout << std::endl; } else if (action == "save") { print_title("Save Weights"); dbn.store(task.w_desc.file); std::cout << "Weights saved" << std::endl; } else if (action == "load") { print_title("Load Weights"); dbn.load(task.w_desc.file); std::cout << "Weights loaded" << std::endl; } else { std::cout << "dllp: error: Invalid action: " << action << std::endl; } } //TODO }