Пример #1
0
/*
 * Starts the train application in window `wnd` at most once.
 *
 * The first call marks the application as started and then calls
 * init_train(wnd).  Every later call only prints a notice to the shell
 * window; the function-local static keeps its value for the lifetime of
 * the program.
 */
void run_train_app(WINDOW* wnd)
{
    static int started = 0;

    if (!started) {
        /* Mark as started before init_train(), matching the original order. */
        started = 1;
        init_train(wnd);
        return;
    }

    wprintf(&shell_wnd, "Train application already running.\n\n");
}
Пример #2
0
int model::read_data()
{
	utils::randomize();

	// read training data
	trngdata = new dataset;
	if (trngdata->read_data(ddir + dfile, &word2id))
	{
		std::cout << "Fail to read training data!\n";
		return 1;
	}

	// according to testing type initialise
	if (testing_type == NO_TEST || testing_type == SELF_TEST)
	{
		M = trngdata->M;
		V = trngdata->V;

		// randomly initialise model variables for training
		std::cout << "Now randomly initialising model variables for training" << std::endl;
		if (init_train())
		{
			std::cout << "Error: Failed to allocate model variables!" << std::endl;
			return 1;
		}
	}
	else if (testing_type == SEPARATE_TEST)
	{
		// read held-out testing data
		testdata = new dataset;
		if (testdata->read_data(ddir + tfile, &word2id))
		{
			std::cout << "Error: Failed to read training corpus!" << std::endl;
			return 1;
		}
		M = trngdata->M;
		test_M = testdata->M;
		V = testdata->V;
		trngdata->V = V;

		// randomly initialise model variables for training
		std::cout << "Now randomly initialising model variables for training" << std::endl;
		if (init_train())
		{
			std::cout << "Error: Failed to allocate model variables!" << std::endl;
			return 1;
		}
		// initialise aux variables for testing
		if (init_test())
		{
			std::cout << "Error: Failed to initialise testing variables!" << std::endl;
			return 1;
		}
	}

	// write word map to file
	if (dataset::write_wordmap(mdir + "wordmap.txt", &word2id))
	{
		return 1;
	}

	// construct the reverse map (currently stupidly by reading back)
	for (auto w : word2id)
	{
		id2word[w.second] = w.first;
	}

	return 0;
}
Пример #3
0
/*
 * Dispatches a "train <param1> <param2>" shell command.
 *
 * Sub-commands handled (param1):
 *   clear        - train_clear_mem_buffer(); param2 must be empty
 *   speed <d>    - set_train_speed(); param2 must be one digit in [0-5]
 *   switch <#C>  - train_set_switch(#, C); # in [1-9], C is 'R' or 'G'
 *   see <id>     - report get_status_of_contact(id); id is 1-2 chars
 *   rev          - train_switch_directions()
 *   run          - init_train(&train_wnd); param2 must be empty
 *
 * Returns 0 on success, -1 on invalid arguments or an unknown sub-command.
 */
int trainCommand(char* param1, char* param2)
{
	if (stringCompare(param1, "clear") == 0 && stringCompare(param2, "") == 0)
	{
		train_clear_mem_buffer();
		return 0;
	}

	// train speed setting: exactly one digit in [0-5].
	// param2[0] is tested first so an empty string never reads param2[1].
	if (stringCompare(param1, "speed") == 0)
	{
		if (param2[0] >= '0' && param2[0] <= '5' && param2[1] == 0)
		{
			set_train_speed(&param2[0]);
			return 0;
		}
		wprintf(&shell_wnd, "Invalid Speed Selection\n");
		return -1;
	}

	if (stringCompare(param1, "switch") == 0)
	{
		// Use strlen(): indexing param2[2] on a shorter string reads past
		// its terminator.
		if (strlen(param2) != 2)
		{
			wprintf(&shell_wnd, "Invalid switch command format\n");
			wprintf(&shell_wnd, "Must be of form train switch #C\n");
			wprintf(&shell_wnd, "Where # is [1-9] and C is 'R' or 'G'\n");
			return -1;
		}
		if (param2[0] < '1' || param2[0] > '9')
		{
			wprintf(&shell_wnd, "Invalid switch identifier\n");
			return -1;
		}
		if (param2[1] != 'G' && param2[1] != 'R')
		{
			wprintf(&shell_wnd, "Invalid switch position\n");
			return -1;
		}
		train_set_switch(param2[0], param2[1]);
		return 0;
	}

	if (stringCompare(param1, "see") == 0)
	{
		// Contact identifier must be 1 or 2 characters long.
		size_t len = strlen(param2);
		if (len == 0 || len > 2)
		{
			wprintf(&shell_wnd, "Invalid parameter 2 length\n");
			return -1;
		}
		if (get_status_of_contact(param2) == 0)
		{
			wprintf(&shell_wnd, "train not detected\n");
		}
		else
		{
			// User input goes through "%s", never as the format string itself,
			// so '%' characters in param2 cannot be interpreted as conversions.
			// NOTE(review): assumes this wprintf supports %s - confirm.
			wprintf(&shell_wnd, "train found on track %s\n", param2);
		}
		return 0;
	}

	if (stringCompare(param1, "rev") == 0)
	{
		wprintf(&shell_wnd, "Reversing train direction\n");
		train_switch_directions();
		return 0;
	}

	if (stringCompare(param1, "run") == 0 && stringCompare(param2, "") == 0)
	{
		wprintf(&shell_wnd, "Train Executing\n");
		trainRunning = 0;
		init_train(&train_wnd);
		return 0;
	}

	wprintf(&shell_wnd, "See 'help' for list of train commands\n");
	return -1;
}
Пример #4
0
/**
 * Trains `model` on the events from `data_reader` with Generalized Iterative
 * Scaling (GIS), running exactly `iter` iterations.
 *
 * Each iteration calls cmpt_estimated() (which fills correct_num and
 * log_likelihood), logs progress, then updates every entry of
 * model._weight_mat:
 *   - sigma2 != 0: Gaussian-prior update; the step `delta` comes from newton();
 *   - sigma2 == 0: plain GIS update
 *       (log(observed) - log(estimated)) / _correct_constant,
 *     skipping feature/label pairs whose observed value is 0.
 *
 * @param data_reader  training data source
 * @param model        model whose _weight_mat is updated in place
 * @param iter         number of iterations to run
 * @param tol          tolerance; currently only logged.
 *                     NOTE(review): no convergence check uses it - confirm intent.
 * @param sigma2       Gaussian prior variance; 0 turns the prior off, < 0 is rejected
 * @return 0 on success, -1 on any failure
 */
int GisTrainer::train (
        DataReader& data_reader, LinearModel& model,
        int iter, float tol, float sigma2) {

    // sigma2 == 0 is valid (prior disabled), so only negative values are rejected.
    if (sigma2 < 0.0) {
        log_warn ("gauss prior should not be negative.");
        return -1;
    }

    int ret = build_param(data_reader, model);
    if (ret != 0) {
        log_warn("build parameter failed.");
        return -1;
    }

    ret = init_train(data_reader);
    if (ret != 0) {
        log_warn("init train failed.");
        return -1;
    }

    log_notice("start GIS iterations...");
    log_notice("number of feature:      %d.", _feat_num);
    log_notice("number of label:        %d.", _label_num);
    log_notice("Tolerance:              %E.", tol);
    log_notice("Gaussian Penalty:       %s", (sigma2?"on":"off"));
    log_notice("objective: min sum {-log p(y|x)} + 1/(2*sigma2)*||w||^2");


    log_notice("iters   loglikelihood    training accuracy");
    log_notice("==========================================");

    // Passed by reference to cmpt_estimated() every iteration; kept outside
    // the loop exactly as before.
    int   correct_num    = 0;
    float log_likelihood = 0.0;

    for (int cur_iter=0; cur_iter<iter; cur_iter++) {

        ret = cmpt_estimated(data_reader, model, correct_num, log_likelihood);
        if (ret != 0) {
            log_warn("compute estimated failed.");
            return -1;
        }
        log_notice("%3d\t%f\t  %.3f%%", cur_iter, log_likelihood/_tot_event_count, 
                        (float)correct_num/_tot_event_count*100);

        float* weight_mat = model._weight_mat;

        //update parameter
        if (sigma2) {
            // One `delta` per iteration, shared across all entries (as before):
            // newton() takes it by reference and may use its previous value.
            float delta = 0.0;

            for (int feat_id=0; feat_id<_feat_num; feat_id++) {
                for (int label_id=0; label_id<_label_num; label_id++) {
                    const int idx = feat_id*_label_num + label_id;

                    ret = newton(_estimated[idx], _observed[idx],
                                 weight_mat[idx], sigma2, delta);
                    if (ret != 0) {
                        log_warn("newton method failed.");
                        return -1;
                    }

                    weight_mat[idx] += delta;
                }
            }
        } else {
            for (int feat_id=0; feat_id<_feat_num; feat_id++) {
                for (int label_id=0; label_id<_label_num; label_id++) {
                    const int idx = feat_id*_label_num + label_id;

                    // unseen feature: log(0) is undefined, leave the weight as is
                    if (_observed[idx] == 0.0) {
                        continue;
                    }

                    const float log_observed  = log(_observed[idx]);
                    const float log_estimated = _estimated[idx]==0 ? LOG_ZERO
                                                : log(_estimated[idx]);

                    weight_mat[idx] += (log_observed-log_estimated)/_correct_constant;
                }
            }
        }

    }

    log_notice ("train by gis_trainer success.");
    return 0;
}