/*
 * Copy constructor: deep-copies the training set owned by `other`.
 *
 * A new FANN training set with the same dimensions is allocated and every
 * input/output row is copied into it. Rows are copied as fann_type so the
 * copy is correct whatever numeric type the linked FANN build uses
 * (the previous sizeof(float) truncated rows in double builds).
 *
 * If `other` holds no data, or the allocation fails, mData is left NULL.
 */
ViFannTrain::ViFannTrain(const ViFannTrain &other)
{
	mData = NULL;
	if(other.mData == NULL)
	{
		return;
	}

	const fann_train_data *source = other.mData;
	mData = fann_create_train(source->num_data, source->num_input, source->num_output);
	if(mData == NULL)
	{
		return; // allocation failed; FANN has already reported the error
	}

	for(unsigned int i = 0; i < source->num_data; ++i)
	{
		memcpy(mData->input[i], source->input[i], sizeof(fann_type) * source->num_input);
		memcpy(mData->output[i], source->output[i], sizeof(fann_type) * source->num_output);
	}
}
/* * INTERNAL FUNCTION Reads training data from a file descriptor. */ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename) { unsigned int num_input, num_output, num_data, i, j; unsigned int line = 1; struct fann_train_data *data; if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3) { fann_error(NULL, FANN_E_CANT_READ_TD, filename, line); return NULL; } line++; data = fann_create_train(num_data, num_input, num_output); if(data == NULL) { return NULL; } for(i = 0; i != num_data; i++) { for(j = 0; j != num_input; j++) { if(!fann_scanvalue(file, FANNSCANF, &data->input[i][j])) { fann_error(NULL, FANN_E_CANT_READ_TD, filename, line); fann_destroy_train(data); return NULL; } } line++; for(j = 0; j != num_output; j++) { if(!fann_scanvalue(file, FANNSCANF, &data->output[i][j])) { fann_error(NULL, FANN_E_CANT_READ_TD, filename, line); fann_destroy_train(data); return NULL; } } line++; } return data; }
/*
 * Creates training data from a callback function.
 *
 * Allocates a training set of num_data samples, then calls user_function
 * once per sample with (sample index, num_input, num_output, input row,
 * output row) so the callback can fill that sample's rows in place.
 * Returns the filled set, or NULL if the allocation failed (in which case
 * the callback is never invoked).
 */
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data, unsigned int num_input, unsigned int num_output, void (FANN_API *user_function)( unsigned int, unsigned int, unsigned int, fann_type * , fann_type * ))
{
	struct fann_train_data *data;
	unsigned int sample;

	data = fann_create_train(num_data, num_input, num_output);
	if(data != NULL)
	{
		for(sample = 0; sample < num_data; ++sample)
		{
			user_function(sample, num_input, num_output, data->input[sample], data->output[sample]);
		}
	}
	return data;
}
bool ViFann::setWeights(const Weights &initialization, const qreal &minimum, const qreal &maximum) { if(mNetwork == NULL) return false; mWeights = initialization; if(initialization == Random) { fann_randomize_weights(mNetwork, minimum, maximum); mWeightsMinimum = minimum; mWeightsMaximum = maximum; } else if(initialization == WidrowNguyen) { // Create fake training set so that FANN can determine the min and max values fann_train_data *data = fann_create_train(1, 2, 1); data->input[0][0] = 1; data->input[0][1] = -1; fann_init_weights(mNetwork, data); fann_destroy_train(data); } return true; }
int main( int argc, char ** argv) { float mse=1000; unsigned int num_train=R_NUM; unsigned int num_test=T_NUM; struct fann_train_data* data ; unsigned int i; const float desired_error = (const float) E_DES; const unsigned int epochs_between_reports = N_EPR; unsigned int bitf_limit=0; unsigned int bitf=bitf_limit+1; struct fann *ann; #if MIMO_FANN printf("MIMO fann\n"); #else printf("Old fann\n"); #endif #ifdef USE_XOR_DATA if (argc<2) { printf("Error: please supply a data file\n"); return -1; } printf("Using %s\n", argv[1]); data=fann_read_train_from_file(argv[1]); #else printf("Generating training data\n"); data = fann_create_train(S_DIM, I_DIM, O_DIM); for ( i=0; i< S_DIM; i++) { f1(data, i); } #endif ann=setup_net(data); #if VERBOSE fann_print_parameters(ann); #endif for (i=0; mse>desired_error && i!=num_train && bitf>bitf_limit; i++) { #if VERBOSE mse=train_epoch_debug(ann, data, i); #else mse=fann_train_epoch(ann, data); #endif bitf=fann_get_bit_fail(ann); if ( !((i) % epochs_between_reports)) printf("Epochs %8d. Current error: %.10f. Bit fail: %u\n", i+(!i), mse, bitf); /*printf ("[ %7u ] MSE Error : %.10e ###################\n", i, mse);*/ } printf("Epochs %8d. Current error: %.10f. Bit fail: %u\n", i+(!i), mse, bitf); printf("Testing network. %f\n", fann_test_data(ann, data)); gettimeofday(&tv_start,NULL); for (i=0; i!=num_test; i++) fann_run_data(ann, data); gettimeofday(&tv_now,NULL); report("---",0); #if 1 printf("Trying to save network\n"); #if MIMO_FANN fann_save(ann, "saved_mimo.net"); fann_destroy(ann); ann=fann_create_from_file("saved_mimo.net"); fann_save(ann, "saved_mimo2.net"); fann_destroy(ann); #else fann_save(ann, "saved_old.net"); #endif #endif return 0; }
/*
 * Allocates an empty FANN training set with dataCount samples, each having
 * `inputs` input values and `outputs` output values.
 *
 * fann_create_train takes unsigned counts, so a negative argument would be
 * converted to a huge size. Such arguments are rejected instead, and mData
 * is left NULL. This is the same state fann_create_train itself produces
 * when allocation fails.
 */
ViFannTrain::ViFannTrain(const int &dataCount, const int &inputs, const int &outputs)
{
	if(dataCount < 0 || inputs < 0 || outputs < 0)
	{
		mData = NULL;
		return;
	}
	mData = fann_create_train(dataCount, inputs, outputs);
}