Ejemplo n.º 1
0
/**
 * @method rspamd_fann:train(inputs, outputs)
 * Trains neural network with a single sample. Inputs and outputs should be
 * flat tables of numbers whose sizes match the number of input and output
 * neurons of the network, e.g.
 *     {0, 1, 1} -> {0}
 * @param {table} inputs input sample
 * @param {table} outputs expected output sample
 * @return {boolean} true if the sample has been learned, false if the object or the arguments are invalid
 */
static gint
lua_fann_train (lua_State *L)
{
#ifndef WITH_FANN
	/* Built without FANN support: nothing is returned to Lua */
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);
	guint ninputs, noutputs, j;
	fann_type *cur_input, *cur_output;
	gboolean ret = FALSE;

	if (f != NULL) {
		if (lua_type (L, 2) != LUA_TTABLE || lua_type (L, 3) != LUA_TTABLE) {
			/* lua_rawgeti below is only valid for tables */
			msg_err ("inputs and outputs for train must be tables");
		}
		else {
			/* First check sanity, call for table.getn for that */
			ninputs = rspamd_lua_table_size (L, 2);
			noutputs = rspamd_lua_table_size (L, 3);

			if (ninputs != fann_get_num_input (f) ||
				noutputs != fann_get_num_output (f)) {
				msg_err ("bad number of inputs(%d, expected %d) and "
						"output(%d, expected %d) args for train",
						ninputs, fann_get_num_input (f),
						noutputs, fann_get_num_output (f));
			}
			else {
				/* Copy both Lua tables into plain C arrays for fann_train */
				cur_input = g_malloc (ninputs * sizeof (*cur_input));

				for (j = 0; j < ninputs; j ++) {
					lua_rawgeti (L, 2, j + 1);
					cur_input[j] = lua_tonumber (L, -1);
					lua_pop (L, 1);
				}

				cur_output = g_malloc (noutputs * sizeof (*cur_output));

				for (j = 0; j < noutputs; j++) {
					lua_rawgeti (L, 3, j + 1);
					cur_output[j] = lua_tonumber (L, -1);
					lua_pop (L, 1);
				}

				fann_train (f, cur_input, cur_output);
				g_free (cur_input);
				g_free (cur_output);

				ret = TRUE;
			}
		}
	}

	lua_pushboolean (L, ret);

	return 1;
#endif
}
Ejemplo n.º 2
0
// Looks up the ANN definition stored at index `i` and verifies that its
// network has exactly `inc` inputs and `outc` outputs.
// Returns the definition on success, or 0 when the lookup fails or either
// count differs. Diagnostics are printed only when the `msg` template
// argument is true.
template<bool msg> AnnDef *Ann_GetAnnDef( int i, unsigned int inc, unsigned int outc )
{
    AnnDef *const entry = getAnnDef(i);
    if( entry == 0 ) {
        if(msg) {
            Print("Could not get ANN at index %i\n", i);
        }
        return 0;
    }

    struct fann *const net = entry->_ann;

    // Inputs are validated first; outputs are only inspected once inputs match.
    const unsigned int netInputs = fann_get_num_input(net);
    if( netInputs != inc ) {
        if(msg) {
            Print( "Error: Input count mismatch ugen: %i / ann: %i\n", inc, netInputs );
        }
        return 0;
    }

    const unsigned int netOutputs = fann_get_num_output(net);
    if( netOutputs != outc ) {
        if(msg) {
            Print( "Error: Output count mismatch ugen: %i / ann: %i\n", outc, netOutputs );
        }
        return 0;
    }

    return entry;
}
Ejemplo n.º 3
0
/*! ann:__tostring()
 *# Produces the string representation of a neural net used by Lua's virtual machine
 *x print(ann)
 *-
 */
static int ann_tostring(lua_State *L)
{
	struct fann **self = luaL_checkudata(L, 1, FANN_METATABLE);
	unsigned int n_inputs, n_outputs, n_neurons;

	luaL_argcheck(L, self != NULL, 1, "'neural net' expected");

	/* The description lists input count, output count and total neuron count. */
	n_inputs = fann_get_num_input(*self);
	n_outputs = fann_get_num_output(*self);
	n_neurons = fann_get_total_neurons(*self);

	lua_pushfstring(L, "[[FANN neural network: %d %d %d]]",
					n_inputs, n_outputs, n_neurons);

	return 1;
}
Ejemplo n.º 4
0
/*! ann:run(input1, input2, ..., inputn)
 *# Evaluates the neural network for the given inputs.
 *x xor = ann:run(-1, 1)
 *-
 */
static int ann_run(lua_State *L)
{
	/*
	 * Arguments: the network userdata followed by one number per input
	 * neuron. Returns one number per output neuron. Raises a Lua error when
	 * the argument count does not match the network's input count or when an
	 * argument is not a number.
	 */
	struct fann **ann;

	int nin, nout, expected, i;
	fann_type *input, *output;

	ann = luaL_checkudata(L, 1, FANN_METATABLE);
	luaL_argcheck(L, ann != NULL, 1, "'neural net' expected");

	/* Compare as ints so the values match the %d conversions below. */
	nin = lua_gettop(L) - 1;
	expected = (int)fann_get_num_input(*ann);
	if(nin != expected)
		return luaL_error(L, "wrong number of inputs: expected %d, got %d", expected, nin);

	nout = (int)fann_get_num_output(*ann);

#ifdef FANN_VERBOSE
	printf("Evaluating neural net: %d inputs, %d outputs\n", nin, nout);
#endif

	/*
	 * One slot is needed for the scratch userdata plus one per output value;
	 * Lua only guarantees a small number of free slots on entry, so reserve
	 * them explicitly to avoid overflowing the stack on wide output layers.
	 */
	luaL_checkstack(L, nout + 1, "too many outputs");

	/*
	 * The input buffer is a userdata so that Lua's garbage collector owns it:
	 * luaL_checknumber may raise an error and no explicit free is needed.
	 */
	input = lua_newuserdata(L, nin*(sizeof *input));

	for(i = 0; i < nin; i++)
	{
		input[i] = luaL_checknumber(L, i + 2);
#ifdef FANN_VERBOSE
		printf("Input %d's value is %f\n", i, input[i]);
#endif
	}

	output = fann_run(*ann, input);
	for(i = 0; i < nout; i++)
	{
#ifdef FANN_VERBOSE
		printf("Output %d's value is %f\n", i, output[i]);
#endif
		lua_pushnumber(L, output[i]);
	}

	/* The outputs are the topmost nout stack values. */
	return nout;
}
Ejemplo n.º 5
0
// Runs `steps` incremental training steps on `neural`.
// For each step a sample is read from `buffer`, the network output is
// recorded before and after a single fann_train() call, and the per-step
// record is appended to `results` (which is cleared first).
void vTrainThread::run(){
    results.clear();
    const unsigned int num_input = fann_get_num_input(neural);

    // Prepare the sample buffer for the training steps; it is reused by
    // every step and released before returning.
    float *data = new float[num_input];
    for(int i = 0; i < steps; i++){
        struct train_result step;
        memset(&step, 0, sizeof(train_result));

        signal->logMessage(DEBUG, QString("Step%1").arg(i));
        memset(data, 0, num_input*sizeof(float));
        float desired_output = 0.0;

        QByteArray input = buffer->getBuffer(num_input);
        QByteArray diff = buffer->getDiff(num_input);

        // Never read past the shorter array nor write past the sample buffer.
        int count = input.size();
        if(diff.size() < count){
            count = diff.size();
        }
        if(static_cast<unsigned int>(count) > num_input){
            count = static_cast<int>(num_input);
        }

        // The target value is the sum of absolute byte differences.
        for(int j = 0; j < count; j++){
            data[j] = static_cast<float>(input.at(j));
            const float delta = static_cast<float>(diff.at(j));
            desired_output += (delta < 0) ? -delta : delta;
        }

        signal->logMessage(DEBUG, QString("   Diff Sum:     %1").arg(desired_output));

        desired_output /=255;

        step.need_result = desired_output;

        signal->logMessage(DEBUG, QString("   need output:  %1").arg(desired_output));
        float* var = fann_run(neural, data);

        step.output_before_train = *var;
        signal->logMessage(DEBUG, QString("   value before: %1").arg(*var));


        fann_train(neural, data, &desired_output);

        var = fann_run(neural, data);

        // Record what the network actually produces after training,
        // not the target value it was trained towards.
        step.output_after_train = *var;

        signal->logMessage(DEBUG, QString("   value after:  %1").arg(*var));

        // error1: the network produced a non-zero value where zero was wanted.
        step.error1 = (step.need_result == 0 && step.output_before_train != 0);

        // error2: the network produced zero where a non-zero value was wanted.
        step.error2 = (step.need_result != 0 && step.output_before_train == 0);

        results.append(step);

    }

    delete[] data;
}
Ejemplo n.º 6
0
/***
 * @method rspamd_fann:get_inputs()
 * Returns the count of input neurons of the neural network
 * @return {number} number of inputs (nil if the object is not a valid network)
 */
static gint
lua_fann_get_inputs (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *net = rspamd_lua_check_fann (L, 1);

	/* Invalid object: report nil instead of a count */
	if (net == NULL) {
		lua_pushnil (L);

		return 1;
	}

	lua_pushnumber (L, fann_get_num_input (net));

	return 1;
#endif
}
Ejemplo n.º 7
0
/**
 * @method rspamd_fann:train_threaded(inputs, outputs, callback, event_base, {params})
 * Trains neural network with batch of samples in a separate thread. Inputs and
 * outputs should be tables of equal size, each row in table should be N inputs
 * and M outputs, e.g.
 *     {{0, 1, 1}, ...} -> {{0}, {1} ...}
 * A Lua error is raised if the arguments are invalid or training cannot be started.
 * @param {table} inputs input samples
 * @param {table} outputs output samples
 * @param {callback} function that is called when train is completed
 * @param {ev_base} event_base event base used to wait for the training thread
 * @param {table} params optional parameters: `max_epochs` (default 1000) and `desired_mse` (default 0.0001)
 */
static gint
lua_fann_train_threaded (lua_State *L)
{
#ifndef WITH_FANN
	return 0;
#else
	struct fann *f = rspamd_lua_check_fann (L, 1);
	guint ninputs, noutputs, ndata, i, j;
	struct lua_fann_train_cbdata *cbdata;
	struct event_base *ev_base = lua_check_ev_base (L, 5);
	GError *err = NULL;
	const guint max_epochs_default = 1000;
	const gdouble desired_mse_default = 0.0001;

	if (f == NULL || lua_type (L, 2) != LUA_TTABLE ||
			lua_type (L, 3) != LUA_TTABLE || lua_type (L, 4) != LUA_TFUNCTION ||
			ev_base == NULL) {
		return luaL_error (L, "invalid arguments");
	}

	/* First check sanity, call for table.getn for that */
	ndata = rspamd_lua_table_size (L, 2);
	ninputs = fann_get_num_input (f);
	noutputs = fann_get_num_output (f);
	cbdata = g_malloc0 (sizeof (*cbdata));
	cbdata->L = L;
	cbdata->f = f;
	cbdata->train = rspamd_fann_create_train (ndata, ninputs, noutputs);
	/* Keep a registry reference to the callback */
	lua_pushvalue (L, 4);
	cbdata->cbref = luaL_ref (L, LUA_REGISTRYINDEX);

	if (rspamd_socketpair (cbdata->pair, 0) == -1) {
		msg_err ("cannot open socketpair: %s", strerror (errno));
		cbdata->pair[0] = -1;
		cbdata->pair[1] = -1;
		goto err;
	}

	for (i = 0; i < ndata; i ++) {
		lua_rawgeti (L, 2, i + 1);

		/* lua_rawgeti on the row is only valid if the row is a table */
		if (lua_type (L, -1) != LUA_TTABLE) {
			msg_err ("input sample %d is not a table", i + 1);
			goto err;
		}

		if (rspamd_lua_table_size (L, -1) != ninputs) {
			msg_err ("invalid number of inputs: %d, %d expected",
					rspamd_lua_table_size (L, -1), ninputs);
			goto err;
		}

		for (j = 0; j < ninputs; j ++) {
			lua_rawgeti (L, -1, j + 1);
			cbdata->train->input[i][j] = lua_tonumber (L, -1);
			lua_pop (L, 1);
		}

		lua_pop (L, 1);
		lua_rawgeti (L, 3, i + 1);

		if (lua_type (L, -1) != LUA_TTABLE) {
			msg_err ("output sample %d is not a table", i + 1);
			goto err;
		}

		if (rspamd_lua_table_size (L, -1) != noutputs) {
			msg_err ("invalid number of outputs: %d, %d expected",
					rspamd_lua_table_size (L, -1), noutputs);
			goto err;
		}

		for (j = 0; j < noutputs; j++) {
			lua_rawgeti (L, -1, j + 1);
			cbdata->train->output[i][j] = lua_tonumber (L, -1);
			lua_pop (L, 1);
		}

		/* Drop the output row so the Lua stack does not grow with ndata */
		lua_pop (L, 1);
	}

	cbdata->max_epochs = max_epochs_default;
	cbdata->desired_mse = desired_mse_default;

	/*
	 * Index 5 is the event base, so the optional params table is the
	 * sixth argument as stated in the method signature.
	 */
	if (lua_type (L, 6) == LUA_TTABLE) {
		rspamd_lua_parse_table_arguments (L, 6, NULL,
				"max_epochs=I;desired_mse=N",
				&cbdata->max_epochs, &cbdata->desired_mse);
	}

	/* Now we can call training in a separate thread */
	rspamd_socket_nonblocking (cbdata->pair[0]);
	event_set (&cbdata->io, cbdata->pair[0], EV_READ, lua_fann_thread_notify,
			cbdata);
	event_base_set (ev_base, &cbdata->io);
	/* TODO: add timeout */
	event_add (&cbdata->io, NULL);
	cbdata->t = rspamd_create_thread ("fann train", lua_fann_train_thread,
			cbdata, &err);

	if (cbdata->t == NULL) {
		msg_err ("cannot create training thread: %e", err);

		if (err) {
			g_error_free (err);
		}

		/* The event refers to cbdata, remove it before cbdata is freed */
		event_del (&cbdata->io);

		goto err;
	}

	return 0;

err:
	/* Release everything acquired so far, then raise a Lua error */
	if (cbdata->pair[0] != -1) {
		close (cbdata->pair[0]);
	}
	if (cbdata->pair[1] != -1) {
		close (cbdata->pair[1]);
	}

	fann_destroy_train (cbdata->train);
	luaL_unref (L, LUA_REGISTRYINDEX, cbdata->cbref);
	g_free (cbdata);
	return luaL_error (L, "invalid arguments");
#endif
}