Exemplo n.º 1
0
// Train the net with all the data.
//
// Creates a new fully connected 3-layer FANN network
// (num_input -> max(1, num_input/2) -> num_output) and assigns it to `net`.
// It then configures the learning rate, activation steepness and symmetric
// sigmoid activations for the hidden and output layers. The weights are
// initialised from the training data, and the net is trained for at most 1000
// epochs or until the MSE drops below 0.001. Finally the training set is
// written to "ddumbassfile.save".
//
// @param num_data   number of training patterns
// @param num_input  number of values in each input pattern
// @param input      num_data input patterns of num_input values each
// @param num_output number of values in each output pattern
// @param output     num_data output patterns of num_output values each
void NeuralNet::train_net( int num_data, int num_input, float** input,int num_output, float ** output)
{
  // NOTE(review): any object previously stored in `net` is not released here;
  // confirm who owns `net` in the class definition (not visible in this chunk).
  net = new FANN::neural_net();
  const float learning_rate = 0.07f;
  const unsigned int num_layers = 3;

  // Half as many hidden neurons as inputs, but never zero: a hidden layer
  // without neurons would cut the inputs off from the outputs
  // (num_input == 1 would otherwise give 0).
  const int half_inputs = num_input / 2;
  const unsigned int num_hidden = static_cast<unsigned int>(half_inputs > 0 ? half_inputs : 1);

  // Explicit casts: brace-initialising unsigned int from a non-constant int is
  // a narrowing conversion, which is ill-formed in C++11.
  unsigned int layers[num_layers] = {static_cast<unsigned int>(num_input),
                                     num_hidden,
                                     static_cast<unsigned int>(num_output)};
  net->create_standard_array(num_layers, layers);
  net->set_learning_rate(learning_rate);
  net->set_activation_steepness_hidden(.1);
  net->set_activation_steepness_output(.1);

  net->set_activation_function_hidden(FANN::SIGMOID_SYMMETRIC_STEPWISE);
  net->set_activation_function_output(FANN::SIGMOID_SYMMETRIC_STEPWISE);

  // The training algorithm is left at the library default; alternatives that
  // were tried include FANN::TRAIN_INCREMENTAL and FANN::TRAIN_QUICKPROP.

  const float desired_error = 0.001f;
  const unsigned int max_iterations = 1000;
  // 0 => train_on_data prints no progress reports.
  const unsigned int epochs_between_reports = 0;

  FANN::training_data data;
  data.set_train_data(num_data, num_input, input, num_output, output);

  // Initialise the weights from the training data, then train.
  net->init_weights(data);
  net->train_on_data(data, max_iterations, epochs_between_reports, desired_error);

  data.save_train("ddumbassfile.save");
}
Exemplo n.º 2
0
void Execution::initializeWeights()
{
    assert(net && "Invalid program state, net uninitialized");
    const int numInputs = net->get_num_input();
    if (trainingInputsSelection.isValid() == 0)
    {
        QMessageBox::warning(this, tr("Information"),
                             tr("Select inputs columns for Training in Execution tab to initialize weights"));
        return;
    }
    if (numInputs != trainingInputsSelection.numColumns())
    {
        QMessageBox::warning(this, tr("Information"),
                             tr("Selected inputs columns must match net inputs"));
        return;
    }

    double **inputPatterns = trainingInputsSelection.getData();
    FANN::training_data data;
    data.set_train_data(trainingInputsSelection.numRows(), numInputs, inputPatterns, 0, 0);
    net->init_weights(data);
}
Exemplo n.º 3
0
int main(int argc, char *argv[])
{
    if (argv[1][0] == 'r') {
        WAVFile inp(argv[2]);
        translate_wav(inp);
        return 0;
    }
    if (argc == 1 || argc % 2 != 1) {
        std::cout << "bad number of training examples\n";
        return -1;
    }

    int to_open = (argc - 1)/2;
    for (int i = 0; i < to_open; i++) {
        WAVFile inp(argv[2*i+1]);
        WAVFile out(argv[2*i+2]);
        add_training_sound(inp, out);
    }

    float *train_in[input_training.size()];
    float *train_out[output_training.size()];
    for (int i = 0; i < input_training.size(); i++) {
        train_in[i] = input_training[i];
        train_out[i] = output_training[i];
    }
    FANN::training_data training;
    training.set_train_data(input_training.size(), (samples_per_segment/2+1)*2,
            train_in, (samples_per_segment/2+1)*2, train_out);
    FANN::neural_net net; 
    const unsigned int layers[] = {(samples_per_segment/2+1)*2, (samples_per_segment/2+1)*2, (samples_per_segment/2+1)*2};
    net.create_standard_array(3, (unsigned int*)layers);
    net.set_activation_function_output(FANN::LINEAR);
    //net.set_activation_function_hidden(FANN::LINEAR);
    net.set_learning_rate(1.2f);
    net.train_on_data(training, 50000, 1, 3.0f);

    net.save("net.net");
}
Exemplo n.º 4
0
void Execution::runTraining()
{
    assert(net && "Illegal program state, netuninitialized");
    if (isTrainingSelectionValid() == 0)
    {
        return;
    }
    emit preparingToTrain(this);

    double **inputPatterns = trainingInputsSelection.getData();
    double **referenceOutputs = trainingReferenceSelection.getData();
    const int numPatterns = trainingInputsSelection.numRows();

    FANN::training_data data;
    data.set_train_data(numPatterns, net->get_num_input(), inputPatterns,
                        net->get_num_output(), referenceOutputs);

    resetErrorButton->setDisabled(true);
    for (int i = 0; i < maxEpochs; ++i)
    {
        Selection& s = trainingErrorSelection;
        std::cout << i << std::endl;
        double mse = net->train_epoch(data);
        if (s.isValid() && (i < s.numRows()))
        {
            client->writeToCell(s.getSheet(), s.getStartR() + i, s.getStartC(), mse);
        }
        if (mse < desiredMSE)
        {
            QMessageBox::information(this, tr("Information"),
                                     tr("Net has reached the desired MSE of ") + QString::number(mse) +
                                     tr(" after ") + QString::number(i) + tr(" epochs."));
            break;
        }
    }
    resetErrorButton->setEnabled(true);
}