// Train the net with all the data void NeuralNet::train_net( int num_data, int num_input, float** input,int num_output, float ** output) //void train_net(FANN::neural_net &net,unsigned int num_data, unsigned int num_input, fann_type **input,unsigned int num_output,fann_type **output) { net = new FANN::neural_net(); const float learning_rate = 0.07f; const unsigned int num_layers=3; int num_hidden=num_input/2; unsigned int layers[3]={num_input,num_hidden,num_output}; net->create_standard_array(num_layers,layers); net->set_learning_rate(learning_rate); net->set_activation_steepness_hidden(.1); net->set_activation_steepness_output(.1); net->set_activation_function_hidden(FANN::SIGMOID_SYMMETRIC_STEPWISE); net->set_activation_function_output(FANN::SIGMOID_SYMMETRIC_STEPWISE); //net.set_training_algorithm(FANN::TRAIN_INCREMENTAL); // Set additional properties such as the training algorithm //net.set_training_algorithm(FANN::TRAIN_QUICKPROP); const float desired_error = 0.001f; const unsigned int max_iterations = 1000; const unsigned int iterations_between_reports = 1000; FANN::training_data data; data.set_train_data(num_data, num_input, input, num_output, output); // Initialize and train the network with the data net->init_weights(data); net->train_on_data(data, max_iterations,0, desired_error); data.save_train("ddumbassfile.save"); }
void Execution::initializeWeights() { assert(net && "Invalid program state, net uninitialized"); const int numInputs = net->get_num_input(); if (trainingInputsSelection.isValid() == 0) { QMessageBox::warning(this, tr("Information"), tr("Select inputs columns for Training in Execution tab to initialize weights")); return; } if (numInputs != trainingInputsSelection.numColumns()) { QMessageBox::warning(this, tr("Information"), tr("Selected inputs columns must match net inputs")); return; } double **inputPatterns = trainingInputsSelection.getData(); FANN::training_data data; data.set_train_data(trainingInputsSelection.numRows(), numInputs, inputPatterns, 0, 0); net->init_weights(data); }
int main(int argc, char *argv[]) { if (argv[1][0] == 'r') { WAVFile inp(argv[2]); translate_wav(inp); return 0; } if (argc == 1 || argc % 2 != 1) { std::cout << "bad number of training examples\n"; return -1; } int to_open = (argc - 1)/2; for (int i = 0; i < to_open; i++) { WAVFile inp(argv[2*i+1]); WAVFile out(argv[2*i+2]); add_training_sound(inp, out); } float *train_in[input_training.size()]; float *train_out[output_training.size()]; for (int i = 0; i < input_training.size(); i++) { train_in[i] = input_training[i]; train_out[i] = output_training[i]; } FANN::training_data training; training.set_train_data(input_training.size(), (samples_per_segment/2+1)*2, train_in, (samples_per_segment/2+1)*2, train_out); FANN::neural_net net; const unsigned int layers[] = {(samples_per_segment/2+1)*2, (samples_per_segment/2+1)*2, (samples_per_segment/2+1)*2}; net.create_standard_array(3, (unsigned int*)layers); net.set_activation_function_output(FANN::LINEAR); //net.set_activation_function_hidden(FANN::LINEAR); net.set_learning_rate(1.2f); net.train_on_data(training, 50000, 1, 3.0f); net.save("net.net"); }
void Execution::runTraining() { assert(net && "Illegal program state, netuninitialized"); if (isTrainingSelectionValid() == 0) { return; } emit preparingToTrain(this); double **inputPatterns = trainingInputsSelection.getData(); double **referenceOutputs = trainingReferenceSelection.getData(); const int numPatterns = trainingInputsSelection.numRows(); FANN::training_data data; data.set_train_data(numPatterns, net->get_num_input(), inputPatterns, net->get_num_output(), referenceOutputs); resetErrorButton->setDisabled(true); for (int i = 0; i < maxEpochs; ++i) { Selection& s = trainingErrorSelection; std::cout << i << std::endl; double mse = net->train_epoch(data); if (s.isValid() && (i < s.numRows())) { client->writeToCell(s.getSheet(), s.getStartR() + i, s.getStartC(), mse); } if (mse < desiredMSE) { QMessageBox::information(this, tr("Information"), tr("Net has reached the desired MSE of ") + QString::number(mse) + tr(" after ") + QString::number(i) + tr(" epochs.")); break; } } resetErrorButton->setEnabled(true); }