Esempio n. 1
0
/**
 * Discards the current network and builds a new FANN network with the given layout.
 *
 * @param type            Network kind to create (Standard, and - in non-GPU builds - Sparse or Shortcut).
 * @param neurons         Neuron count per layer, input layer first and output layer last.
 *                        Entries equal to 0 are skipped when building the layer list.
 * @param connectionRate  Connection rate passed to FANN; only used for Sparse networks.
 * @return true if the network was created; false if the layout is empty or has a
 *         non-positive input/output count, the type is not supported by this build,
 *         or FANN failed to create the network.
 */
bool ViFann::setStructure(const Type &type, const QList<int> &neurons, const qreal &connectionRate)
{
	#ifdef GPU
		if(type != Standard)
		{
			LOG("The GPU version of FANN currently doesn't support shortcut, sparse or cascade networks.", QtFatalMsg);
			exit(-1);
		}
	#endif
	clear();

	// QList::first()/last() must not be called on an empty list, and a network
	// without inputs or outputs cannot be built.
	if(neurons.isEmpty() || neurons.first() <= 0 || neurons.last() <= 0) return false;

	mType = type;
	mInputCount = neurons.first();
	mOutputCount = neurons.last();
	mNeurons.clear();
	for(mI = 0; mI < neurons.size(); ++mI)
	{
		// A zero entry means "no such layer" and is left out of the layout.
		if(neurons[mI] != 0) mNeurons.append(neurons[mI]);
	}

	// Release the previous buffers before allocating new ones. delete [] on a
	// null pointer is a no-op, so no guard is needed (the old "== NULL" guard
	// was inverted and leaked the previous buffers).
	// NOTE(review): assumes clear() leaves mInput/mOutput either valid or NULL - confirm.
	delete [] mInput;
	delete [] mOutput;
	mInput = new float[mInputCount];
	mOutput = new float[mOutputCount];

	// FANN expects a plain array of unsigned layer sizes. A heap array is used
	// instead of a variable-length array, which is not standard C++.
	const unsigned int layers = static_cast<unsigned int>(mNeurons.size());
	unsigned int *layerNeurons = new unsigned int[layers];
	for(mI = 0; mI < mNeurons.size(); ++mI) layerNeurons[mI] = static_cast<unsigned int>(mNeurons[mI]);

	bool supported = true;
	if(type == Standard) mNetwork = fann_create_standard_array(layers, layerNeurons);
	#ifndef GPU
		else if(type == Sparse)
		{
			mNetwork = fann_create_sparse_array(connectionRate, layers, layerNeurons);
			mConnectionRate = connectionRate;
		}
		else if(type == Shortcut) mNetwork = fann_create_shortcut_array(layers, layerNeurons);
	#endif
	else supported = false;

	delete [] layerNeurons;

	// Unsupported type, or FANN could not allocate/create the network.
	if(!supported || mNetwork == NULL) return false;

	fann_set_train_stop_function(mNetwork, FANN_STOPFUNC_MSE);

	if(ENABLE_CALLBACK)
	{
		fann_set_callback(mNetwork, &ViFann::trainCallback);
		mMseTotal.clear();
		mMseCount = 0;
	}

	return true;
}
Esempio n. 2
0
/*
 * Trains a cascade network with FANN and reports its error.
 *
 * Steps:
 *   1. Remove the file named by histfile and seed the random generator.
 *   2. Load "train.dat" and "test.dat".
 *   3. Create a 2-layer shortcut network sized from the training data and
 *      configure it (RPROP, sigmoid activations, linear error function).
 *   4. Cascade-train on the training data, then print MSE and bit-fail for
 *      both data sets.
 *   5. Save the network to "cascaded.net" and release all FANN resources.
 *
 * Command-line arguments are not used.
 * Returns 0 on success, 1 if a data file cannot be read or the network
 * cannot be created.
 */
int main(int argc,char **argv)
{
    (void) argc;
    (void) argv;

    /* NOTE(review): presumably clears a history file left by a previous run - confirm. */
    unlink(histfile);
    srand ( time ( NULL ) );

    train_data = fann_read_train_from_file ( "train.dat" );
    if ( train_data == NULL )
    {
        fprintf ( stderr, "Unable to read training data from train.dat\n" );
        return 1;
    }

    test_data = fann_read_train_from_file ( "test.dat" );
    if ( test_data == NULL )
    {
        fprintf ( stderr, "Unable to read test data from test.dat\n" );
        fann_destroy_train ( train_data );
        return 1;
    }
    // signal ( 2, sig_term );

    // Disabled experiments: scaling of the raw data sets.
    // fann_scale_train_data ( train_data, 0, 1.54 );
    // fann_scale_train_data ( test_data, 0, 1.54 );
    // cln_test_data=fann_duplicate_train_data(test_data);
    cln_train_data=fann_duplicate_train_data(train_data);

    printf ( "Creating cascaded network.\n" );
    ann =
        fann_create_shortcut ( 2, fann_num_input_train_data ( train_data ),
                               fann_num_output_train_data ( train_data ) );
    if ( ann == NULL )
    {
        fprintf ( stderr, "Unable to create the network\n" );
        if ( cln_train_data != NULL )
            fann_destroy_train ( cln_train_data );
        fann_destroy_train ( test_data );
        fann_destroy_train ( train_data );
        return 1;
    }

    fann_set_training_algorithm ( ann, FANN_TRAIN_RPROP );
    fann_set_activation_function_hidden ( ann, FANN_SIGMOID );
    fann_set_activation_function_output ( ann, FANN_SIGMOID);
    fann_set_train_error_function ( ann, FANN_ERRORFUNC_LINEAR );

    // Disabled experiments: network-driven scaling and cascade tuning.
    // if (fann_set_scaling_params(ann, train_data,-1.0f,1.0f,0.0f, 1.0f)==-1)
    //     printf("set scaling error: %s\n",fann_get_errno((struct fann_error*)ann));
    // fann_scale_train_input(ann,train_data);
    // fann_scale_output_train_data(train_data,0.0f,1.0f);
    // fann_scale_input_train_data(train_data, -1.0,1.0f);
    // fann_scale_output_train_data(test_data,-1.0f,1.0f);
    // fann_scale_input_train_data(test_data, -1.0,1.0f);
    // fann_scale_train(ann,train_data);
    // fann_scale_train(ann,weight_data);
    // fann_scale_train(ann,test_data);
    // fann_set_cascade_output_change_fraction(ann, 0.1f);
    // fann_set_cascade_candidate_change_fraction(ann, 0.1f);
    // fann_set_cascade_output_stagnation_epochs ( ann, 180 );
    // fann_set_cascade_weight_multiplier ( ann, ( fann_type ) 0.1f );

    fann_set_callback ( ann, cascade_callback );
    if ( !multi )
    {
        // The steepness values are stored, but the call that would apply them
        // to the network is currently disabled (see below).
        // Previously tried: steepness[0] = 0.22; steepness[1] = 0.55 / 0.33 / 0.01;
        //                   steepness[3] = 0.11; steepness = 0.5;
        steepness[0] = 0.9;
        steepness[1] = 1.0;
        // fann_set_cascade_activation_steepnesses ( ann, steepness, 2);

        // Candidate activation functions offered to the cascade trainer.
        // Previously tried: activation = FANN_SIN_SYMMETRIC;
        //                   activation[2] = FANN_ELLIOT_SYMMETRIC;
        //                   activation[4] = FANN_GAUSSIAN_SYMMETRIC;
        //                   activation[5] = FANN_SIGMOID;
        activation[0] = FANN_SIGMOID;
        activation[1] = FANN_LINEAR_PIECE;
        activation[2] = FANN_ELLIOT;
        activation[3] = FANN_COS;
        activation[4] = FANN_SIN;
        fann_set_cascade_activation_functions ( ann, activation, 5);
        /*   fann_set_cascade_num_candidate_groups ( ann,
                                                  fann_num_input_train_data
                                                  ( train_data ) ); */
    }
    else
    {
        // Nothing is configured in "multi" mode at the moment.
        // fann_set_cascade_activation_steepnesses(ann, &steepness, 0.75);
        // fann_set_cascade_num_candidate_groups ( ann, 1 );
    }

    /* TODO: weight mult > 0.01 */
    /*  if ( training_algorithm == FANN_TRAIN_QUICKPROP )
      {
          fann_set_learning_rate ( ann, 0.35f );
      }
      else
      {
          fann_set_learning_rate ( ann, 0.7f );
      }
      fann_set_bit_fail_limit ( ann, ( fann_type ) 0.9f );*/

    // fann_set_train_stop_function(ann, FANN_STOPFUNC_BIT);

    // fann_scale_output_train_data(train_data,0.0f,1.0f);
    // fann_scale_input_train_data(train_data, -1.0f,1.0f);
    // fann_scale_output_train_data(test_data, 0.0f,1.0f);
    // fann_scale_input_train_data(test_data, -1.0f,1.0f);

    // fann_randomize_weights ( ann, -0.2f, 0.2f );
    fann_init_weights ( ann, train_data );

    printf ( "Training network.\n" );
    fann_cascadetrain_on_data ( ann, train_data, max_neurons,
                                1, desired_error );
    fann_print_connections ( ann );

    // fann_get_bit_fail reports the count from the most recent fann_test_data
    // call, so each one is read straight after its own test run.
    mse_train = fann_test_data ( ann, train_data );
    bit_fail_train = fann_get_bit_fail ( ann );
    mse_test = fann_test_data ( ann, test_data );
    bit_fail_test = fann_get_bit_fail ( ann );
    printf
    ( "\nTrain error: %.08f, Train bit-fail: %d, Test error: %.08f, Test bit-fail: %d\n\n",
      mse_train, bit_fail_train, mse_test, bit_fail_test );

    printf ( "Saving cascaded network.\n" );
    fann_save ( ann, "cascaded.net" );

    // Release everything that was allocated above, including the duplicate
    // of the training data.
    if ( cln_train_data != NULL )
        fann_destroy_train ( cln_train_data );
    fann_destroy_train ( train_data );
    fann_destroy_train ( test_data );
    fann_destroy ( ann );
    return 0;

}