コード例 #1
0
ファイル: rbm.cpp プロジェクト: hakimsd9/RBM
// Load the binary-RBM training data into the member `input_features`.
//
// Features are read from "<data_file>.fea": number_of_data_points rows of
// num_visible_units whitespace-separated values. Each value > 0 is replaced
// by 1; any other value (0 or negative) is stored as read.
// Integer labels are then read from "<data_file>.labels", one per data point.
//
// Missing input files are reported on stderr instead of crashing.
//
// NOTE(review): the labels are only read into a temporary buffer that is
// released before returning -- nothing here stores them. If the class has an
// `output_labels` member, the local declaration shadows it and this was
// probably meant to fill the member; confirm against the class definition.
void rbm::load_data(std::string data_file){
    const std::string feature_file = data_file + ".fea";
    const std::string label_file = data_file + ".labels";

    // Rows are value-initialized (all 0) so entries that cannot be read from
    // a missing or short file are 0 rather than indeterminate.
    input_features = new float*[number_of_data_points];
    for (int i=0; i<number_of_data_points; i++){
        input_features[i] = new float[num_visible_units]();
    }

    std::ifstream infile(feature_file.c_str(), std::ios::in);
    if (!infile){
        fprintf(stderr, "rbm::load_data: cannot open feature file %s\n", feature_file.c_str());
    }
    for (int i=0; i<number_of_data_points; i++){
        for (int j=0; j<num_visible_units; j++){
            infile >> input_features[i][j];
            // Binary RBM: any positive value becomes 1.
            if (input_features[i][j] > 0){
                input_features[i][j] = 1;
            }
        }
    }
    infile.close();

    FILE * label_fp = fopen(label_file.c_str(), "r");
    if (label_fp == NULL){
        // fscanf/fclose on a NULL FILE* is undefined behavior, so stop here.
        fprintf(stderr, "rbm::load_data: cannot open label file %s\n", label_file.c_str());
        return;
    }
    int * output_labels = new int[number_of_data_points]();
    for (int i = 0; i < number_of_data_points; i++){
        if (fscanf(label_fp, "%d", &output_labels[i]) != 1){
            fprintf(stderr, "rbm::load_data: %s holds fewer than %d labels\n",
                    label_file.c_str(), number_of_data_points);
            break;
        }
    }
    fclose(label_fp);
    // Nothing else references this buffer; release it instead of leaking it.
    delete[] output_labels;
}
コード例 #2
0
ファイル: util.cpp プロジェクト: hakimsd9/RBM
// Convert a dense text dataset (the .txt produced from the original .mat file)
// into libsvm's sparse format, so that results obtained with the raw visible
// units can be compared with those obtained with the newly learned features.
//
//   data_file            - path prefix: features are read from "<data_file>.fea"
//                          (sample_size rows of num_visible_units whitespace-
//                          separated values), labels from "<data_file>.labels"
//                          (one integer per sample).
//   sample_size          - number of samples (rows) to convert.
//   num_visible_units    - number of feature columns per row.
//   libsvm_features_file - output path (overwritten).
//
// Features are binarized as for the binary RBM: values > 0 become 1, other
// values are kept as read. Each output line is "<label> <idx>:<value> ..."
// with 1-based indices and only non-zero entries. Samples with no non-zero
// feature produce no output line.
//
// If an input or output file cannot be opened, an error is printed to stderr
// and nothing is written. Values or labels missing from a short file are
// treated as 0.
void transform_to_libsvm_format(std::string data_file, int sample_size, int num_visible_units, std::string libsvm_features_file){
    const std::string feature_file = data_file + ".fea";
    const std::string label_file = data_file + ".labels";

    std::ifstream feat_file(feature_file.c_str(), std::ios::in);
    if (!feat_file){
        fprintf(stderr, "transform_to_libsvm_format: cannot open %s\n", feature_file.c_str());
        return;
    }

    FILE * label_fp = fopen(label_file.c_str(), "r");
    if (label_fp == NULL){
        fprintf(stderr, "transform_to_libsvm_format: cannot open %s\n", label_file.c_str());
        return;
    }

    std::ofstream out_file(libsvm_features_file.c_str());
    if (!out_file.is_open()){
        fprintf(stderr, "transform_to_libsvm_format: cannot open %s for writing\n",
                libsvm_features_file.c_str());
        fclose(label_fp);
        return;
    }

    // Output line i depends only on label i and feature row i, so the data is
    // streamed row by row -- no dense copy of the dataset (and nothing to leak).
    bool warned_missing_labels = false;
    for (int i=0; i<sample_size; i++){
        int label = 0;
        if (fscanf(label_fp, "%d", &label) != 1 && !warned_missing_labels){
            fprintf(stderr, "transform_to_libsvm_format: %s holds fewer than %d labels\n",
                    label_file.c_str(), sample_size);
            warned_missing_labels = true;
        }

        int non_zero_features = 0;
        for (int j=0; j<num_visible_units; j++){
            float value = 0.0f;  // stays 0 if the feature file runs short
            feat_file >> value;
            // Binary RBM: any positive value becomes 1.
            if (value > 0){
                value = 1;
            }
            if (value != 0){
                non_zero_features++;
                if (non_zero_features == 1){
                    out_file << label;
                }
                out_file << " " << (j+1) << ":" << value;   // indexing starts at 1 in libsvm
            }
        }
        if (non_zero_features > 0){
            out_file << '\n';
        }
    }

    fclose(label_fp);
}
コード例 #3
0
DataProviderDTang<Dtype>::DataProviderDTang(const pose::DataProviderParameter& param) : DataProvider<Dtype>(param) {
  // Builds depth_paths_ / annos_ from the label file at
  // param.dtang_param().label_path().
  //
  // Each line is expected to be
  //   <image path relative to <label dir>/Depth> followed by 3 * param.n_pts() numbers
  // (x y z per point). The number of values per line is not validated.
  //
  // A line is kept only if
  //   (a) the resolved image path has the form <label dir>/Depth/<dir>/<file>
  //       (grandparent directory named "Depth"; per the original note this
  //       selects the original, non-rotated data), and
  //   (b) its 0-based line index is a multiple of param_.inc(). The index counts
  //       every line, including ones rejected by (a).
  // NOTE(review): if inc() can be 0 the modulo below divides by zero -- confirm
  // the parameter is validated upstream.
  std::cout << "[INFO] read label file " << param.dtang_param().label_path() << std::endl;

  const boost::filesystem::path label_path = param.dtang_param().label_path();

  std::ifstream label_file(label_path.c_str());
  if (!label_file.is_open()) {
    // Previously this was indistinguishable from an empty label file.
    std::cerr << "[ERROR] could not open label file " << label_path << std::endl;
  }

  std::string line;
  int idx = 0;

  while (std::getline(label_file, line)) {
    std::istringstream iss(line);

    std::string relative_depth_path;
    iss >> relative_depth_path;

    boost::filesystem::path depth_path = label_path.parent_path() / "Depth" /
      boost::filesystem::path(relative_depth_path);

    std::vector<cv::Vec3f> anno(param.n_pts());
    for (size_t pt_idx = 0; pt_idx < anno.size(); ++pt_idx) {
      iss >> anno[pt_idx](0);
      iss >> anno[pt_idx](1);
      iss >> anno[pt_idx](2);
    }

    // only load original data - no rotated
    if (depth_path.parent_path().parent_path().filename().string() == "Depth" && (idx % this->param_.inc() == 0)) {
      depth_paths_.push_back(depth_path);
      annos_.push_back(anno);
    }

    idx++;
  }

  this->max_idx_ = depth_paths_.size();

  std::cout << "[INFO] label file " << label_path << " contained " << depth_paths_.size() << " annotated depth images" << std::endl;
}
コード例 #4
0
ファイル: rbm.cpp プロジェクト: hakimsd9/RBM
// Given learned RBM parameters and new input vectors, compute the hidden units
// for every example and save them in libsvm's sparse input format.
//
//   data_file             - path prefix: features are read from "<data_file>.fea"
//                           (number_of_examples rows of num_visible_units values),
//                           labels from "<data_file>.labels" (one int per example).
//   mdl_weight_file       - num_hidden_units rows of num_visible_units weights.
//   mdl_visible_bias_file - not read by this function.
//   mdl_hidden_bias_file  - num_hidden_units hidden-unit biases.
//   learned_features_file - output path; each line is "<label> <j>:<h_j> ..."
//                           with 1-based indices and only non-zero h_j.
//   number_of_examples    - number of examples to read and convert.
//
// Input features are binarized (values > 0 become 1). Examples whose hidden
// vector is entirely zero produce no output line. If the label file cannot be
// opened, an error is printed to stderr and nothing is computed or written.
// Values missing from short or unreadable input files are treated as 0.
void rbm::generate_features(std::string data_file, const char * mdl_weight_file,
                const char * mdl_visible_bias_file, const char * mdl_hidden_bias_file,
                const char * learned_features_file, int number_of_examples){
    const std::string feature_file = data_file + ".fea";
    const std::string label_file = data_file + ".labels";

    // Open the label file before allocating anything so a missing file can
    // bail out early; fscanf/fclose on a NULL FILE* is undefined behavior.
    FILE * label_fp = fopen(label_file.c_str(), "r");
    if (label_fp == NULL){
        fprintf(stderr, "rbm::generate_features: cannot open label file %s\n", label_file.c_str());
        return;
    }

    // Weights: num_hidden_units x num_visible_units. All buffers below are
    // value-initialized so anything a file fails to supply is 0, not indeterminate.
    float ** weights = new float*[num_hidden_units];
    for (int i=0; i<num_hidden_units; i++){
        weights[i] = new float[num_visible_units]();
    }
    std::ifstream weight_file(mdl_weight_file, std::ios::in);
    if (!weight_file){
        fprintf(stderr, "rbm::generate_features: cannot open weight file %s\n", mdl_weight_file);
    }
    for (int i=0; i<num_hidden_units; i++){
        for (int j=0; j<num_visible_units; j++){
            weight_file >> weights[i][j];
        }
    }
    weight_file.close();

    // Hidden biases: one per hidden unit.
    float * hidden_bias = new float[num_hidden_units]();
    std::ifstream hidden_bias_file(mdl_hidden_bias_file, std::ios::in);
    if (!hidden_bias_file){
        fprintf(stderr, "rbm::generate_features: cannot open hidden bias file %s\n", mdl_hidden_bias_file);
    }
    for (int i=0; i<num_hidden_units; i++){
        hidden_bias_file >> hidden_bias[i];
    }
    hidden_bias_file.close();

    // Input features: number_of_examples x num_visible_units, binarized.
    float ** input_f = new float*[number_of_examples];
    for (int i=0; i<number_of_examples; i++){
        input_f[i] = new float[num_visible_units]();
    }
    std::ifstream feat_file(feature_file.c_str(), std::ios::in);
    if (!feat_file){
        fprintf(stderr, "rbm::generate_features: cannot open feature file %s\n", feature_file.c_str());
    }
    for (int i=0; i<number_of_examples; i++){
        for (int j=0; j<num_visible_units; j++){
            feat_file >> input_f[i][j];
            // Binary RBM: any positive value becomes 1.
            if (input_f[i][j] > 0){
                input_f[i][j] = 1;
            }
        }
    }
    feat_file.close();

    int * output_labels = new int[number_of_examples]();
    for (int i = 0; i < number_of_examples; i++){
        if (fscanf(label_fp, "%d", &output_labels[i]) != 1){
            fprintf(stderr, "rbm::generate_features: %s holds fewer than %d labels\n",
                    label_file.c_str(), number_of_examples);
            break;
        }
    }
    fclose(label_fp);

    float ** hidden = new float*[number_of_examples];
    for (int i=0; i<number_of_examples; i++){
        hidden[i] = new float[num_hidden_units]();
    }

    // for each example, compute the corresponding hidden features
    for (int i=0; i<number_of_examples; i++){
        compute_hidden(weights, hidden_bias, input_f[i], hidden[i], num_hidden_units, num_visible_units);
    }

    // write the learned features; libsvm uses a sparse format
    std::ofstream out_file(learned_features_file);
    if (out_file.is_open()){
        for (int i=0; i<number_of_examples; i++){
            int non_zero_features = 0;
            for (int j=0; j<num_hidden_units; j++){
                // TODO: consider a tolerance (e.g. |h| > 1e-4) instead of exact != 0.
                if (hidden[i][j] != 0){
                    non_zero_features++;
                    if (non_zero_features == 1){
                        out_file << output_labels[i];
                    }
                    out_file << " " << (j+1) << ":" << hidden[i][j];    // indexing starts at 1 in libsvm
                }
            }
            if (non_zero_features > 0){
                out_file << '\n';
            }
        }
    }
    out_file.close();

    // Release every buffer, including each row of the 2-D arrays (previously
    // only the outer array of input_f was freed and everything else leaked).
    for (int i=0; i<num_hidden_units; i++){
        delete[] weights[i];
    }
    delete[] weights;
    delete[] hidden_bias;
    for (int i=0; i<number_of_examples; i++){
        delete[] input_f[i];
        delete[] hidden[i];
    }
    delete[] input_f;
    delete[] hidden;
    delete[] output_labels;
}
コード例 #5
0
ファイル: main.cpp プロジェクト: imthexie/giraffe
int main(int argc, char **argv)
{
	InitializeFast();

	Backend backend;

	ANNEvaluator evaluator;

	ANNMoveEvaluator mevaluator(evaluator);

	// if eval.net exists, use the ANN evaluator
	// if both eval.net and meval.net exist, use the ANN move evaluator

	if (FileReadable(EvalNetFilename))
	{
		backend.SetEvaluator(&evaluator);

		std::cout << "# Using ANN evaluator" << std::endl;

		if (FileReadable(MoveEvalNetFilename))
		{
			std::cout << "# Using ANN move evaluator" << std::endl;
			backend.SetMoveEvaluator(&mevaluator);
		}
		else
		{
			std::cout << "# Using static move evaluator" << std::endl;
			backend.SetMoveEvaluator(&gStaticMoveEvaluator);
		}
	}
	else
	{
		std::cout << "# Using static evaluator" << std::endl;
		std::cout << "# Using static move evaluator" << std::endl;

		backend.SetEvaluator(&Eval::gStaticEvaluator);
		backend.SetMoveEvaluator(&gStaticMoveEvaluator);
	}

	// first we handle special operation modes
	if (argc >= 2 && std::string(argv[1]) == "tdl")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 3)
		{
			std::cout << "Usage: " << argv[0] << " tdl positions" << std::endl;
			return 0;
		}
        //try 
        //{
	    Learn::TDL(argv[2]);
        //}
        //catch(...){}
		return 0;
	}
	else if (argc >= 2 && std::string(argv[1]) == "conv")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 3)
		{
			std::cout << "Usage: " << argv[0] << " conv FEN" << std::endl;
			return 0;
		}

		std::stringstream ss;

		for (int i = 2; i < argc; ++i)
		{
			ss << argv[i] << ' ';
		}

		Board b(ss.str());

		std::vector<FeaturesConv::FeatureDescription> ret;
		FeaturesConv::ConvertBoardToNN(b, ret);

		return 0;
	}
    else if(argc >= 2 && std::string(argv[1]) == "train_eval")
    {
        // ./giraffe train_eval <list filename> <epd_data_path> <epd_label_path> <epochs> 

        InitializeSlowBlocking(evaluator, mevaluator);
        std::ifstream epd_file_list(argv[2]);
        std::string epd_data_path = argv[3];
        std::string epd_label_path = argv[4];
        std::string epd_filename;
        std::vector<std::string> filenames;
        while(std::getline(epd_file_list, epd_filename)){
            filenames.push_back(epd_filename);
        }
        int epochs = std::stoi(argv[5]);
        
        for(int i = 0; i < epochs*filenames.size(); i++){
            std::string epd_path_full = epd_data_path + "/" + filenames[i % filenames.size()];
            std::string label_path_full = epd_label_path + "/" + filenames[i % filenames.size()] + ".xie";
            std::cout << label_path_full << std::endl;
            std::ifstream epd_file(epd_path_full);
            std::ifstream label_file(label_path_full);
            std::string fen;
            std::vector<std::string> fens;
            //std::cout << "Reading FENS" << std::endl;
            while(std::getline(epd_file, fen))
            {
                fens.push_back(fen);
            } 
            //std::cout << "Reading labels" << std::endl;
            
            std::string label;
            NNMatrixRM mat_labels = NNMatrixRM(fens.size(), 1);
            //std::vector<int> labels; 
            int idx = 0;
            while(std::getline(label_file, label)){
                //labels.push_back(stoi(label));
                mat_labels(idx, 0) = std::stoi(label);
                idx++; 
            }
            //std::cout << "Getting feature descriptions" << std::endl;
            Board dummy;
            std::vector<FeaturesConv::FeatureDescription> ret;
            FeaturesConv::ConvertBoardToNN(dummy, ret);

            //std::cout << "Starting Training" << std::endl;
            
            std::ofstream outNet(argv[6]);
            //evaluator.Serialize(outNet);
            evaluator.TrainLoop(fens, mat_labels, 10, ret);
            evaluator.Serialize(outNet);
        }
/*
        std::ifstream epd_file(argv[2]);
        std::ifstream label_file(argv[3]);
        std::string fen;
        std::vector<std::string> fens;
        std::cout << "Reading FENS" << std::endl;
        while(std::getline(epd_file, fen))
        {
            fens.push_back(fen);
        } 
        std::cout << "Reading labels" << std::endl;
        
        std::string label;
        NNMatrixRM mat_labels = NNMatrixRM(fens.size(), 1);
        //std::vector<int> labels; 
        int idx = 0;
        while(std::getline(label_file, label)){
            //labels.push_back(stoi(label));
            mat_labels(idx, 0) = std::stoi(label);
            idx++; 
        }
        std::cout << "Getting feature descriptions" << std::endl;
        int epochs = std::stoi(argv[4]);
        Board dummy;
		std::vector<FeaturesConv::FeatureDescription> ret;
		FeaturesConv::ConvertBoardToNN(dummy, ret);

        std::cout << "Starting Training" << std::endl;
        
        std::ofstream outNet(argv[5]);
        evaluator.Serialize(outNet);
        evaluator.TrainLoop(fens, mat_labels, epochs, ret);
        evaluator.Serialize(outNet);
  */      
        return 0;
    }
    else if (argc >= 2 && std::string(argv[1]) == "conv_file")
    {
        InitializeSlowBlocking(evaluator, mevaluator);
        
        std::ifstream inFile(argv[2]);
        std::ofstream outFile(argv[3]);

        std::string fen;
        std::vector<std::string> fens;
        while(std::getline(inFile, fen))
        {
            fens.push_back(fen);
            std::cout << fen << std::endl;
            /*
            Board b(fen);
		    std::vector<FeaturesConv::FeatureDescription> ret;
		    FeaturesConv::ConvertBoardToNN(b, ret);
            std::stringstream ss;
            for (int i = 0; i < ret.size()-1; i++)
            {
                ss << ret[i].XieToString() << " ";

            }
            ss << ret[ret.size() - 1].XieToString();
            outFile << ss.str() << std::endl;
            std::cout << ss.str() << std::endl;
            std::cout << "****" << std::endl;
            */
        }

        std::vector<FeaturesConv::FeatureDescription> dummy(363);

        NNMatrixRM ret = evaluator.BoardsToFeatureRepresentation_(fens, dummy);
        for(int64_t row = 0; row < ret.rows(); ++row)
        {
            for(int64_t col = 0; col < ret.cols(); ++ col)
            {
                outFile << ret(row,col) << ' '; 
            }
            outFile << '\n';
        }


        inFile.close();
        outFile.close();

        return 0;

    }
	else if (argc >= 2 && std::string(argv[1]) == "mconv")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 3)
		{
			std::cout << "Usage: " << argv[0] << " mconv FEN" << std::endl;
			return 0;
		}

		std::stringstream ss;

		for (int i = 2; i < argc; ++i)
		{
			ss << argv[i] << ' ';
		}

		Board b(ss.str());

		MoveList moves;
		b.GenerateAllLegalMoves<Board::ALL>(moves);

		NNMatrixRM ret;

		FeaturesConv::ConvertMovesInfo convInfo;

		FeaturesConv::ConvertMovesToNN(b, convInfo, moves, ret);

		for (int64_t row = 0; row < ret.rows(); ++row)
		{
			for (int64_t col = 0; col < ret.cols(); ++col)
			{
				std::cout << ret(row, col) << ' ';
			}
			std::cout << std::endl;
		}

		return 0;
	}
	else if (argc >= 2 && std::string(argv[1]) == "bench")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		double startTime = CurrentTime();

		static const NodeBudget BenchNodeBudget = 64*1024*1024;

		Search::SyncSearchNodeLimited(Board("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"), BenchNodeBudget, backend.GetEvaluator(), backend.GetMoveEvaluator());
		Search::SyncSearchNodeLimited(Board("2r2rk1/pp3pp1/b2Pp3/P1Q4p/RPqN2n1/8/2P2PPP/2B1R1K1 w - - 0 1"), BenchNodeBudget, backend.GetEvaluator(), backend.GetMoveEvaluator());
		Search::SyncSearchNodeLimited(Board("8/1nr3pk/p3p1r1/4p3/P3P1q1/4PR1N/3Q2PK/5R2 w - - 0 1"), BenchNodeBudget, backend.GetEvaluator(), backend.GetMoveEvaluator());
		Search::SyncSearchNodeLimited(Board("5R2/8/7r/7P/5RPK/1k6/4r3/8 w - - 0 1"), BenchNodeBudget, backend.GetEvaluator(), backend.GetMoveEvaluator());
		Search::SyncSearchNodeLimited(Board("r5k1/2p2pp1/1nppr2p/8/p2PPp2/PPP2P1P/3N2P1/R3RK2 w - - 0 1"), BenchNodeBudget, backend.GetEvaluator(), backend.GetMoveEvaluator());
		Search::SyncSearchNodeLimited(Board("8/R7/8/1k6/1p1Bq3/8/4NK2/8 w - - 0 1"), BenchNodeBudget, backend.GetEvaluator(), backend.GetMoveEvaluator());

		std::cout << "Time: " << (CurrentTime() - startTime) << "s" << std::endl;

		return 0;
	}
	else if (argc >= 2 && std::string(argv[1]) == "check_bounds")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 3)
		{
			std::cout << "Usage: " << argv[0] << " check_bounds <EPD/FEN file>" << std::endl;
			return 0;
		}

		std::ifstream infile(argv[2]);

		if (!infile)
		{
			std::cerr << "Failed to open " << argv[2] << " for reading" << std::endl;
			return 1;
		}

		uint64_t passes = 0;
		uint64_t total = 0;
		float windowSizeTotal = 0.0f;

		std::string fen;
		std::vector<std::string> fens;
		while (std::getline(infile, fen))
		{
			fens.push_back(fen);
		}

		#pragma omp parallel
		{
			auto evaluatorCopy = evaluator;

			#pragma omp for
			for (size_t i = 0; i < fens.size(); ++i)
			{
				Board b(fens[i]);
				float windowSize = 0.0f;
				bool res = evaluatorCopy.CheckBounds(b, windowSize);

				#pragma omp critical(boundCheckAccum)
				{
					if (res)
					{
						++passes;
					}

					++total;

					windowSizeTotal += windowSize;
				}
			}
		}

		std::cout << passes << "/" << total << std::endl;
		std::cout << "Average window size: " << (windowSizeTotal / total) << std::endl;

		return 0;
	}
	else if (argc >= 2 && std::string(argv[1]) == "train_bounds")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 4)
		{
			std::cout << "Usage: " << argv[0] << " train_bounds <EPD/FEN file> <output net file>" << std::endl;
			return 0;
		}

		std::ifstream infile(argv[2]);

		if (!infile)
		{
			std::cerr << "Failed to open " << argv[2] << " for reading" << std::endl;
			return 1;
		}

		std::vector<FeaturesConv::FeatureDescription> featureDescriptions;
		Board dummyBoard;
		FeaturesConv::ConvertBoardToNN(dummyBoard, featureDescriptions);

		std::string line;
		std::vector<std::string> fens;
		while (std::getline(infile, line))
		{
			fens.push_back(line);
		}

		const size_t BlockSize = 256;
		const size_t PrintInterval = BlockSize * 100;

		for (size_t i = 0; i < (fens.size() - BlockSize); i += BlockSize)
		{
			if (i % PrintInterval == 0)
			{
				std::cout << i << "/" << fens.size() << std::endl;
			}

			std::vector<std::string> positions;

			for (size_t j = 0; j < BlockSize; ++j)
			{
				positions.push_back(fens[i + j]);
			}

			evaluator.TrainBounds(positions, featureDescriptions, 1.0f);
		}

		std::ofstream outfile(argv[3]);

		if (!outfile)
		{
			std::cerr << "Failed to open " << argv[3] << " for writing" << std::endl;
			return 1;
		}

		evaluator.Serialize(outfile);

		return 0;
	}
	else if (argc >= 2 && std::string(argv[1]) == "sample_internal")
	{
		// MUST UNCOMMENT "#define SAMPLING" in static move evaluator

		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 4)
		{
			std::cout << "Usage: " << argv[0] << " sample_internal <EPD/FEN file> <output file>" << std::endl;
			return 0;
		}

		std::ifstream infile(argv[2]);
		std::ofstream outfile(argv[3]);

		if (!infile)
		{
			std::cerr << "Failed to open " << argv[2] << " for reading" << std::endl;
			return 1;
		}

		std::string fen;
		std::vector<std::string> fens;
		static const uint64_t maxPositions = 5000000;
		uint64_t numPositions = 0;
		while (std::getline(infile, fen) && numPositions < maxPositions)
		{
			fens.push_back(fen);
			++numPositions;
		}

		#pragma omp parallel
		{
			auto evaluatorCopy = evaluator;

			#pragma omp for
			for (size_t i = 0; i < fens.size(); ++i)
			{
                std::cout << i << std::endl;
				Board b(fens[i]);

				Search::SyncSearchNodeLimited(b, 1000, &evaluatorCopy, &gStaticMoveEvaluator, nullptr, nullptr);
			}
		}

		for (const auto &pos : gStaticMoveEvaluator.samples)
		{
			outfile << pos << std::endl;
		}

		return 0;
	}
	else if (argc >= 2 && std::string(argv[1]) == "label_bm")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 4)
		{
			std::cout << "Usage: " << argv[0] << " label_bm <EPD/FEN file> <output file>" << std::endl;
			return 0;
		}

		std::ifstream infile(argv[2]);
		std::ofstream outfile(argv[3]);

		if (!infile)
		{
			std::cerr << "Failed to open " << argv[2] << " for reading" << std::endl;
			return 1;
		}

		std::string fen;
		std::vector<std::string> fens;
		static const uint64_t maxPositions = 5000000;
		uint64_t numPositions = 0;
		while (std::getline(infile, fen) && numPositions < maxPositions)
		{
			Board b(fen);

			if (b.GetGameStatus() != Board::ONGOING)
			{
				continue;
			}

			fens.push_back(fen);
			++numPositions;
		}

		std::vector<std::string> bm(fens.size());

		uint64_t numPositionsDone = 0;

		double lastPrintTime = CurrentTime();
		size_t lastDoneCount = 0;

		#pragma omp parallel
		{
			auto evaluatorCopy = evaluator;

			#pragma omp for schedule(dynamic)
			for (size_t i = 0; i < fens.size(); ++i)
			{
				Board b(fens[i]);

				Search::SearchResult result = Search::SyncSearchNodeLimited(b, 100000, &evaluatorCopy, &gStaticMoveEvaluator, nullptr, nullptr);

				bm[i] = b.MoveToAlg(result.pv[0]);

				#pragma omp critical(numPositionsAndOutputFileUpdate)
				{
					++numPositionsDone;

					outfile << fens[i] << '\n';
					outfile << bm[i] << '\n';

					if (omp_get_thread_num() == 0)
					{
						double currentTime = CurrentTime();
						double timeDiff = currentTime - lastPrintTime;
						if (timeDiff > 1.0)
						{
							std::cout << numPositionsDone << '/' << fens.size() << std::endl;
							std::cout << "Positions per second: " << static_cast<double>(numPositionsDone - lastDoneCount) / timeDiff << std::endl;

							lastPrintTime = currentTime;
							lastDoneCount = numPositionsDone;
						}
					}
				}
			}
		}

		return 0;
	}
	else if (argc >= 2 && std::string(argv[1]) == "train_move_eval")
	{
		InitializeSlowBlocking(evaluator, mevaluator);

		if (argc < 4)
		{
			std::cout << "Usage: " << argv[0] << " train_move_eval <EPD/FEN file> <output file>" << std::endl;
			return 0;
		}

		std::ifstream infile(argv[2]);

		if (!infile)
		{
			std::cerr << "Failed to open " << argv[2] << " for reading" << std::endl;
			return 1;
		}

		std::cout << "Reading positions from " << argv[2] << std::endl;

		std::string fen;
		std::string bestMove;
		std::vector<std::string> fens;
		std::vector<std::string> bestMoves;
		static const uint64_t MaxPositions = 5000000;
		uint64_t numPositions = 0;
		while (std::getline(infile, fen) && std::getline(infile, bestMove) && numPositions < MaxPositions)
		{
			Board b(fen);

			if (b.GetGameStatus() != Board::ONGOING)
			{
				continue;
			}

			fens.push_back(fen);
			bestMoves.push_back(bestMove);

			++numPositions;
		}

		assert(bestMoves.size() == fens.size());

		// now we split a part of it out into a withheld test set
		size_t numTrainExamples = fens.size() * 0.9f;
		std::vector<std::string> fensTest(fens.begin() + numTrainExamples, fens.end());
		std::vector<std::string> bestMovesTest(bestMoves.begin() + numTrainExamples, bestMoves.end());

		static const uint64_t MaxTestingPositions = 10000;

		if (fensTest.size() > MaxTestingPositions)
		{
			fensTest.resize(MaxTestingPositions);
			bestMovesTest.resize(MaxTestingPositions);
		}

		fens.resize(numTrainExamples);
		bestMoves.resize(numTrainExamples);

		std::cout << "Num training examples: " << numTrainExamples << std::endl;
		std::cout << "Num testing examples: " << fensTest.size() << std::endl;

		std::cout << "Starting training" << std::endl;

		ANNMoveEvaluator meval(evaluator);

		meval.Train(fens, bestMoves);

		meval.Test(fensTest, bestMovesTest);

		std::ofstream outfile(argv[3]);

		meval.Serialize(outfile);

		return 0;
	}

	// we need a mutex here because InitializeSlow needs to print, and it may decide to
	// print at the same time as the main command loop (if the command loop isn't waiting)
	std::mutex coutMtx;

	coutMtx.lock();

	// do all the heavy initialization in a thread so we can reply to "protover 2" in time
	std::thread initThread(InitializeSlow, std::ref(evaluator), std::ref(mevaluator), std::ref(coutMtx));

	auto waitForSlowInitFunc = [&initThread, &coutMtx]() { coutMtx.unlock(); initThread.join(); coutMtx.lock(); };

	while (true)
	{
		std::string lineStr;

		coutMtx.unlock();
		std::getline(std::cin, lineStr);
		coutMtx.lock();

		std::stringstream line(lineStr);

		// we set usermove=1, so all commands from xboard start with a unique word
		std::string cmd;
		line >> cmd;

		// this is the list of commands we can process before initialization finished
		if (
			cmd != "xboard" &&
			cmd != "protover" &&
			cmd != "hard" &&
			cmd != "easy" &&
			cmd != "cores" &&
			cmd != "memory" &&
			cmd != "accepted" &&
			cmd != "rejected" &&
			initThread.joinable())
		{
			// wait for initialization to be done
			waitForSlowInitFunc();
		}

		if (cmd == "xboard") {} // ignore since we only support xboard mode anyways
		else if (cmd == "protover")
		{
			int32_t ver;
			line >> ver;

			if (ver >= 2)
			{
				std::string name = "Giraffe";
				if (gVersion != "")
				{
					name += " ";
					name += gVersion;
				}

				std::cout << "feature ping=1 setboard=1 playother=0 san=0 usermove=1 time=1 draw=0 sigint=0 sigterm=0 "
							 "reuse=1 analyze=1 myname=\"" << name << "\" variants=normal colors=0 ics=0 name=0 pause=0 nps=0 "
							 "debug=1 memory=0 smp=0 done=0" << std::endl;

				std::cout << "feature option=\"GaviotaTbPath -path .\"" << std::endl;

				std::cout << "feature done=1" << std::endl;
			}
		}
		else if (cmd == "accepted") {}