void test_cifar_csvtrain(char *filename, char *weightfile) { network net = parse_network_cfg(filename); if(weightfile){ load_weights(&net, weightfile); } srand(time(0)); data test = load_all_cifar10(); matrix pred = network_predict_data(net, test); int i; for(i = 0; i < test.X.rows; ++i){ image im = float_to_image(32, 32, 3, test.X.vals[i]); flip_image(im); } matrix pred2 = network_predict_data(net, test); scale_matrix(pred, .5); scale_matrix(pred2, .5); matrix_add_matrix(pred2, pred); matrix_to_csv(pred); fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1)); free_data(test); }
void test_mnist_csv(char *filename, char *weightfile) { network net = parse_network_cfg(filename); if(weightfile){ load_weights(&net, weightfile); } srand(time(0)); data test; test = load_mnist_data("data/mnist/t10k-images.idx3-ubyte", "data/mnist/t10k-labels.idx1-ubyte", 10000); matrix pred = network_predict_data(net, test); int i; for(i = 0; i < test.X.rows; ++i){ image im = float_to_image(32, 32, 3, test.X.vals[i]); flip_image(im); } matrix pred2 = network_predict_data(net, test); scale_matrix(pred, .5); scale_matrix(pred2, .5); matrix_add_matrix(pred2, pred); matrix_to_csv(pred); fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1)); free_data(test); }
void train_mnist_distill(char *cfgfile, char *weightfile) { data_seed = time(0); srand(time(0)); float avg_loss = -1; char *base = basecfg(cfgfile); printf("%s\n", base); network net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); } printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); char *backup_directory = "backup"; int classes = 10; int N = 50000; int epoch = (*net.seen)/N; data train;// = load_all_mnist10(); matrix soft = csv_to_matrix("results/ensemble.csv"); float weight = .9; scale_matrix(soft, weight); scale_matrix(train.y, 1. - weight); matrix_add_matrix(soft, train.y); while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ clock_t time=clock(); float loss = train_network_sgd(net, train, 1); if(avg_loss == -1) avg_loss = loss; avg_loss = avg_loss*.95 + loss*.05; if(get_current_batch(net)%100 == 0) { printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen); } if(*net.seen/N > epoch){ epoch = *net.seen/N; char buff[256]; sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); save_weights(net, buff); } if(get_current_batch(net)%100 == 0){ char buff[256]; sprintf(buff, "%s/%s.backup",backup_directory,base); save_weights(net, buff); } } char buff[256]; sprintf(buff, "%s/%s.weights", backup_directory, base); save_weights(net, buff); free_network(net); free(base); free_data(train); }