double ff::FBNN::nntest(const FMatrix& x, const FMatrix& y) { // std::cout << "start nntest" << std::endl; FColumn labels; nnpredict(x,y,labels); FColumn expected = rowMaxIndexes(y); std::vector<int> bad = findUnequalIndexes(labels,expected); // std::cout << "end nntest" << std::endl; return double(bad.size()) / x.rows();//Haven't return bad vector.(nntest.m does) }
/*
 * Fraction of the first `numdata` samples that `nn` misclassifies.
 *
 * Sample `row` counts as an error when nnpredict(nn, row, test_x) differs
 * from indOfMaxVal_double(test_y[row], nn->layer[nn->n - 2].units). That is
 * the index of the largest of the first `units` values of test_y[row].
 * Presumably `units` is the output-layer width; TODO confirm against the
 * NN definition.
 *
 * No guard for numdata == 0: the result is then 0.0 / 0 (NaN).
 */
double nneval(NN* nn, int numdata, const double** test_x, const double** test_y)
{
    int misses = 0;
    int row = 0;
    while (row < numdata) {
        if (nnpredict(nn, row, test_x)
                != indOfMaxVal_double(test_y[row], nn->layer[nn->n - 2].units)) {
            ++misses;
        }
        ++row;
    }
    return (double) misses / numdata;
}
int main(int argc, char* argv[]){ int c, i; char *inputImageFile = NULL; char *inputLabelFile = NULL; char *inputModel = "noname.model"; char *outputResult = "result.txt"; while( (c=getopt(argc, argv, "i:y:m:o:h")) != -1 ){ switch(c) { case 'h': printf("usage: \n-i input image file\n-m inputModel\n-o outputResult\n-y (optional)input label file\n\tif provided, it calculate accuracy immediately.\n"); break; case 'i': inputImageFile = optarg; break; case 'y': inputLabelFile = optarg; break; case 'm': inputModel = optarg; break; case 'o': outputResult = optarg; break; case '?': printf("Illegal option\n"); printf("usage: \n-i input image file\n-m inputModel\n-o outputResult\n-y (optional)input label file\n\tif provided, it calculate accuracy immediately.\n"); exit(1); break; default: printf("usage: \n-i input image file\n-m inputModel\n-o outputResult\n-y (optional)input label file\n\tif provided, it calculate accuracy immediately.\n"); exit(1); break; } } int test_num ; double **test_x = mnist_load_data(inputImageFile, &test_num); int dim = 28 * 28; // supposed data dimension // normalization parameters double* mean = (double*) malloc( dim * sizeof(double)); double* sigma = (double*) malloc( dim * sizeof(double)); // load from model NN* mm = importModel(mean, sigma, inputModel); normalize_zscore_apply(test_x, test_num, mean, sigma, dim); if(inputLabelFile != NULL){ double **test_y = mnist_load_labels(inputLabelFile, &test_num); double errorRate = nneval(mm, test_num, (const double**)test_x, (const double**) test_y); printf("Error rate=%f\n", errorRate); } // write result FILE* fp = fopen(outputResult, "w"); for(i=0; i< test_num; i++){ fprintf(fp, "%d\n", nnpredict(mm, i, (const double**)test_x) ); } fclose(fp); return 0; }