Ejemplo n.º 1
0
/*
 * fann_train - one incremental backpropagation step on a single pattern.
 *
 *   ann            network being trained
 *   input          input vector handed to fann_run()
 *   desired_output target vector handed to fann_compute_MSE()
 *
 * The four phases run strictly in this order: forward pass, error against
 * the target, backpropagation of that error, then the weight update.
 */
FANN_EXTERNAL void FANN_API
fann_train(struct fann *ann, fann_type *input, fann_type *desired_output)
{
	/* Forward pass for this input. */
	fann_run(ann, input);
	/* Error of the current output w.r.t. the desired output. */
	fann_compute_MSE(ann, desired_output);
	/* Push the error back through the layers. */
	fann_backpropagate_MSE(ann);
	/* Apply the resulting weight changes (presumably in place — see fann_update_weights). */
	fann_update_weights(ann);
}
Ejemplo n.º 2
0
/*
 * train_epoch_debug - run one batch-training epoch over every pattern in
 * data, optionally dumping delta state via print_deltas() for debugging.
 *
 *   ann   network being trained
 *   data  training set; data->num_data patterns of input/output
 *   iter  epoch number, used only in the debug banners
 *
 * Returns the MSE reported by fann_get_MSE() after the epoch.
 *
 * Build-time switches:
 *   VERBOSE    0..3, controls how much debug output is printed
 *   MIMO_FANN  selects the MIMO code path (no slope arrays, plain
 *              fann_update_weights() at the end of the epoch)
 *   USE_RPROP  (non-MIMO only) iRPROP- instead of plain batch update
 *
 * NOTE(review): the dump counter j is a function-level static, so this
 * function is not reentrant when VERBOSE >= 1.
 */
float train_epoch_debug(struct fann *ann, struct fann_train_data* data, unsigned int iter)
{
	unsigned int i;
#if VERBOSE>=1
	/* Dump counter for print_deltas(). It must exist at every VERBOSE level
	 * that prints: the VERBOSE>=1 block at the end of the epoch uses it too. */
	static unsigned int j=0;
#endif

#if ! MIMO_FANN
	/* Lazily allocate the slope arrays the batch/RPROP updates need. */
	if (ann->prev_train_slopes==NULL)
		fann_clear_train_arrays(ann);
#endif

	fann_reset_MSE(ann);

	/* An empty training set provides no gradient. The batch update below
	 * receives num_data as its pattern count (presumably a divisor — confirm),
	 * so skip the weight update entirely rather than pass 0. */
	if (data->num_data == 0)
		return fann_get_MSE(ann);

	for(i = 0; i < data->num_data; i++)
	{
		fann_run(ann, data->input[i]);
		fann_compute_MSE(ann, data->output[i]);
		fann_backpropagate_MSE(ann);
#if ! MIMO_FANN
		/* Accumulate slopes for every layer after the input layer; the
		 * weights are only changed once, after the loop. */
		fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
#endif

#if VERBOSE>=3
		printf("   ** %u:%u **-AFTER-DELTAS UPDATE-----------------------------------\n", iter, i);
		print_deltas(ann, j++);
#endif

	}
#if VERBOSE>=2
	printf("   ** %u **-BEFORE-WEIGHTS-UPDATE------------------------------------\n", iter);
	print_deltas(ann, j++);
#endif

#if ! MIMO_FANN
#if USE_RPROP
	fann_update_weights_irpropm(ann, 0, ann->total_connections);
#else
	fann_update_weights_batch(ann, data->num_data, 0, ann->total_connections);
#endif
#else /* MIMO_FANN */
	fann_update_weights(ann);
#endif

#if VERBOSE>=1
	printf("   ** %u **-AFTER-WEIGHTS-UPDATE-------------------------------------\n", iter);
	print_deltas(ann, j++);
#endif

	return fann_get_MSE(ann);
}
Ejemplo n.º 3
0
float train_epoch_incremental_mod(struct fann *ann, struct fann_train_data *data, vector< vector<fann_type> >& predicted_outputs)
{

	predicted_outputs.resize(data->num_data,vector<fann_type> (data->num_output));
	fann_reset_MSE(ann);

	for(unsigned int i = 0; i < data->num_data; ++i)
	{
		fann_type* temp_predicted_output=fann_run(ann, data->input[i]);
		for(unsigned int k=0;k<data->num_output;++k)
		{
			predicted_outputs[i][k]=temp_predicted_output[k];
		}

		fann_compute_MSE(ann, data->output[i]);

		fann_backpropagate_MSE(ann);

		fann_update_weights(ann);
	}

	return fann_get_MSE(ann);
}