Exemplo n.º 1
0
 void mlp::set_activation_function(int activation_function, mlp_layer layer)
 {
     if (grt_mlp.validateActivationFunction(activation_function) == false)
     {
         flext::error("activation function %d is invalid, hint should be between 0-%d", activation_function, GRT::Neuron::NUMBER_OF_ACTIVATION_FUNCTIONS - 1);
         return;
     }
     
     GRT::Neuron::ActivationFunctions activation_function_ = (GRT::Neuron::ActivationFunctions)activation_function;
     
     switch (layer)
     {
         case LAYER_INPUT:
             input_activation_function = activation_function_;
             break;
         case LAYER_HIDDEN:
             hidden_activation_function = activation_function_;
             break;
         case LAYER_OUTPUT:
             output_activation_function = activation_function_;
             break;
         default:
             ml::error("no activation function for layer: " + std::to_string(layer));
             return;
     }
     post("activation function set to " + grt_mlp.activationFunctionToString(activation_function_));
 }
Exemplo n.º 2
0
    // Methods
    // NOTE: ANN is special since it supports both regression and classification, we therefore override these methods
    void ann::train()
    {
        const data_type data_type = get_data_type();
        
        GRT::UINT numSamples = data_type == LABELLED_CLASSIFICATION ? classification_data.getNumSamples() : regression_data.getNumSamples();
        
        if (numSamples == 0)
        {
            flext::error("no observations added, use 'add' to add training data");
            return;
        }
        
        bool success = false;
        
        if (data_type == LABELLED_CLASSIFICATION)
        {
            grt_ann.init(
                     classification_data.getNumDimensions(),
                     num_hidden_neurons,
                     classification_data.getNumClasses(),
                     input_activation_function,
                     hidden_activation_function,
                     output_activation_function
                     );
            success = grt_ann.train(classification_data);
        }
        else if (data_type == LABELLED_REGRESSION)
        {
            grt_ann.init(
                     regression_data.getNumInputDimensions(),
                     num_hidden_neurons,
                     regression_data.getNumTargetDimensions(),
                     input_activation_function,
                     hidden_activation_function,
                     output_activation_function
                     );
            success = grt_ann.train(regression_data);
        }

        if (!success)
        {
            flext::error("training failed");
        }
        
        t_atom a_success;
        
        SetInt(a_success, success);
        ToOutAnything(1, get_s_train(), 1, &a_success);
    }
Exemplo n.º 3
0
 void ann::error()
 {
     if (!grt_ann.getTrained())
     {
         flext::error("model not yet trained, send the \"train\" message to train");
         return;
     }
             
     float error_f = grt_ann.getTrainingError();
     t_atom error_a;
     
     SetFloat(error_a, error_f);
     
     ToOutAnything(0, get_s_error(), 1, &error_a);
                   
 }
Exemplo n.º 4
0
 /**
  * Enable or disable null rejection on the wrapped GRT ANN.
  * Reports an error if the GRT setter returns false.
  */
 void ann::set_null_rejection(bool null_rejection)
 {
     if (!grt_ann.setNullRejection(null_rejection))
     {
         flext::error("unable to set null_rejection");
     }
 }
Exemplo n.º 5
0
 /**
  * Forward `gamma` to the wrapped GRT ANN.
  * Reports an error if the GRT setter returns false.
  */
 void ann::set_gamma(float gamma)
 {
     const bool accepted = grt_ann.setGamma(gamma);
     
     if (!accepted)
     {
         flext::error("unable to set gamma");
     }
 }
Exemplo n.º 6
0
 /**
  * Forward the training momentum to the wrapped GRT ANN.
  * If the GRT setter returns false, an error (hinting at the 0-1 range) is reported.
  */
 void ann::set_momentum(float momentum)
 {
     const bool accepted = grt_ann.setMomentum(momentum);
     if (accepted)
     {
         return;
     }
     flext::error("unable to set momentum, hint: should be between 0-1");
 }
Exemplo n.º 7
0
 /**
  * Forward the training (learning) rate to the wrapped GRT ANN.
  * If the GRT setter returns false, an error (hinting at the 0-1 range) is reported.
  */
 void ann::set_training_rate(float training_rate)
 {
     if (!grt_ann.setTrainingRate(training_rate))
     {
         flext::error("unable to set training_rate, hint: should be between 0-1");
     }
 }
Exemplo n.º 8
0
 /**
  * Forward the null-rejection coefficient to the wrapped GRT ANN.
  * If the GRT setter returns false, an error (hinting at a positive value) is reported.
  */
 void ann::set_null_rejection_coeff(float null_rejection_coeff)
 {
     const bool accepted = grt_ann.setNullRejectionCoeff(null_rejection_coeff);
     
     if (!accepted)
     {
         flext::error("unable to set null_rejection_coeff, hint: should be greater than 0");
     }
 }
Exemplo n.º 9
0
 /**
  * Forward the validation-set size to the wrapped GRT ANN.
  * If the GRT setter returns false, an error (hinting at the 0-100 range) is reported.
  */
 void ann::set_validation_set_size(int validation_set_size)
 {
     if (grt_ann.setValidationSetSize(validation_set_size))
     {
         return;
     }
     
     flext::error("unable to set validation_set_size, hint: should be between 0-100");
 }
Exemplo n.º 10
0
 /**
  * Enable or disable randomised training order on the wrapped GRT ANN.
  * Reports an error if the GRT setter returns false.
  */
 void ann::set_randomise_training_order(bool randomise_training_order)
 {
     const bool accepted = grt_ann.setRandomiseTrainingOrder(randomise_training_order);
     if (!accepted)
     {
         flext::error("unable to set randomise_training_order, hint: should be 0 or 1");
     }
 }
Exemplo n.º 11
0
 /**
  * Forward the minimum number of training epochs to the wrapped GRT ANN.
  * If the GRT setter returns false, an error (hinting at a positive value) is reported.
  */
 void ann::set_min_epochs(int min_epochs)
 {
     if (!grt_ann.setMinNumEpochs(min_epochs))
     {
         flext::error("unable to set min_epochs, hint: should be greater than 0");
     }
 }
Exemplo n.º 12
0
 /**
  * Forward the minimum-change stopping threshold to the wrapped GRT ANN.
  * If the GRT setter returns false, an error (hinting at a positive value) is reported.
  */
 void ann::set_min_change(float min_change)
 {
     const bool accepted = grt_ann.setMinChange(min_change);
     if (accepted)
     {
         return;
     }
     flext::error("unable to set min_change, hint: should be greater than 0");
 }
Exemplo n.º 13
0
 /**
  * Enable or disable use of a validation set on the wrapped GRT ANN.
  * Reports an error if the GRT setter returns false.
  */
 void ann::set_use_validation_set(bool use_validation_set)
 {
     if (!grt_ann.setUseValidationSet(use_validation_set))
     {
         flext::error("unable to set use_validation_set, hint: should be 0 or 1");
     }
 }
Exemplo n.º 14
0
 /**
  * Forward the number of random training iterations to the wrapped GRT ANN.
  * If the GRT setter returns false, an error (hinting at a positive value) is reported.
  */
 void ann::set_rand_training_iterations(int rand_training_iterations)
 {
     const bool accepted = grt_ann.setNumRandomTrainingIterations(rand_training_iterations);
     
     if (!accepted)
     {
         flext::error("unable to set rand_training_iterations, hint: should be greater than 0");
     }
 }
Exemplo n.º 15
0
 void ann::set_activation_function(int activation_function, ann_layer layer)
 {
     GRT::Neuron::Type activation_function_ = GRT::Neuron::Type::LINEAR;
     
     try
     {
         activation_function_ = get_grt_neuron_type(activation_function);
     }
     catch (std::exception& e)
     {
         flext::error(e.what());
         return;
     }
     
     if (grt_ann.validateActivationFunction(activation_function_) == false)
     {
         flext::error("activation function %d is invalid, hint should be between 0-%d", activation_function, GRT::Neuron::NUMBER_OF_ACTIVATION_FUNCTIONS - 1);
         return;
     }
     
     switch (layer)
     {
         case LAYER_INPUT:
             input_activation_function = activation_function_;
             break;
         case LAYER_HIDDEN:
             hidden_activation_function = activation_function_;
             break;
         case LAYER_OUTPUT:
             output_activation_function = activation_function_;
             break;
         default:
             ml::error("no activation function for layer: " + std::to_string(layer));
             return;
     }
     post("activation function set to " + grt_ann.activationFunctionToString(activation_function_));
 }
Exemplo n.º 16
0
 /// Write whether the wrapped GRT ANN uses a validation set into the out-parameter.
 void ann::get_use_validation_set(bool &use_validation_set) const
 {
     const auto enabled = grt_ann.getUseValidationSet();
     use_validation_set = enabled;
 }
Exemplo n.º 17
0
 /// Write the wrapped GRT ANN's random-training-iteration count into the out-parameter.
 void ann::get_rand_training_iterations(int &rand_training_iterations) const
 {
     const auto iterations = grt_ann.getNumRandomTrainingIterations();
     rand_training_iterations = iterations;
 }
Exemplo n.º 18
0
 /// Write the wrapped GRT ANN's null-rejection coefficient into the out-parameter.
 void ann::get_null_rejection_coeff(float &null_rejection_coeff) const
 {
     const auto coeff = grt_ann.getNullRejectionCoeff();
     null_rejection_coeff = coeff;
 }
Exemplo n.º 19
0
 /// Write whether null rejection is enabled on the wrapped GRT ANN into the out-parameter.
 void ann::get_null_rejection(bool &null_rejection) const
 {
     const auto enabled = grt_ann.getNullRejectionEnabled();
     null_rejection = enabled;
 }
Exemplo n.º 20
0
 /// Write the wrapped GRT ANN's gamma value into the out-parameter.
 void ann::get_gamma(float &gamma) const
 {
     const auto value = grt_ann.getGamma();
     gamma = value;
 }
Exemplo n.º 21
0
 /// Write the wrapped GRT ANN's training momentum into the out-parameter.
 void ann::get_momentum(float &momentum) const
 {
     const auto value = grt_ann.getMomentum();
     momentum = value;
 }
Exemplo n.º 22
0
 /// Write the wrapped GRT ANN's training rate into the out-parameter.
 void ann::get_training_rate(float &training_rate) const
 {
     const auto rate = grt_ann.getTrainingRate();
     training_rate = rate;
 }
Exemplo n.º 23
0
    void ann::map(int argc, const t_atom *argv)
    {
        const data_type data_type = get_data_type();

        GRT::UINT numSamples = data_type == LABELLED_CLASSIFICATION ? classification_data.getNumSamples() : regression_data.getNumSamples();

        if (numSamples == 0)
        {
            flext::error("no observations added, use 'add' to add training data");
            return;
        }

        if (grt_ann.getTrained() == false)
        {
            flext::error("model has not been trained, use 'train' to train the model");
            return;
        }
        
        GRT::UINT numInputNeurons = grt_ann.getNumInputNeurons();
        GRT::VectorDouble query(numInputNeurons);
        
        if (argc < 0 || (unsigned)argc != numInputNeurons)
        {
            flext::error("invalid input length, expected %d, got %d", numInputNeurons, argc);
        }

        for (uint32_t index = 0; index < (uint32_t)argc; ++index)
        {
            double value = GetAFloat(argv[index]);
            query[index] = value;
        }
        
        bool success = grt_ann.predict(query);
        
        if (success == false)
        {
            flext::error("unable to map input");
            return;
        }
        
        if (grt_ann.getClassificationModeActive())
        {
            const GRT::VectorDouble likelihoods = grt_ann.getClassLikelihoods();
            const GRT::Vector<GRT::UINT> labels = classification_data.getClassLabels();
            const GRT::UINT predicted = grt_ann.getPredictedClassLabel();
            const GRT::UINT classification = predicted == 0 ? 0 : get_class_id_for_index(predicted);
            
            if (likelihoods.size() != labels.size())
            {
                flext::error("labels / likelihoods size mismatch");
            }
            else if (probs)
            {
                AtomList probs_list;

                for (unsigned count = 0; count < labels.size(); ++count)
                {
                    t_atom label_a;
                    t_atom likelihood_a;
                    
                    SetFloat(likelihood_a, static_cast<float>(likelihoods[count]));
                    SetInt(label_a, get_class_id_for_index(labels[count]));
                    
                    probs_list.Append(label_a);
                    probs_list.Append(likelihood_a);
                }
                ToOutAnything(1, get_s_probs(), probs_list);
            }
                 
            ToOutInt(0, classification);
        }
        else if (grt_ann.getRegressionModeActive())
        {
            GRT::VectorDouble regression_data = grt_ann.getRegressionData();
            GRT::VectorDouble::size_type numOutputDimensions = regression_data.size();
            
            if (numOutputDimensions != grt_ann.getNumOutputNeurons())
            {
                flext::error("invalid output dimensions: %d", numOutputDimensions);
                return;
            }
            
            AtomList result;
            
            for (uint32_t index = 0; index < numOutputDimensions; ++index)
            {
                t_atom value_a;
                double value = regression_data[index];
                SetFloat(value_a, value);
                result.Append(value_a);
            }
            
            ToOutList(0, result);
        }
    }
Exemplo n.º 24
0
 /// Write the wrapped GRT ANN's validation-set size into the out-parameter.
 void ann::get_validation_set_size(int &validation_set_size) const
 {
     const auto size = grt_ann.getValidationSetSize();
     validation_set_size = size;
 }
Exemplo n.º 25
0
 /// Write whether the wrapped GRT ANN randomises training order into the out-parameter.
 void ann::get_randomise_training_order(bool &randomise_training_order) const
 {
     const auto randomised = grt_ann.getRandomiseTrainingOrder();
     randomise_training_order = randomised;
 }
Exemplo n.º 26
0
 void mlp::clear()
 {
     grt_mlp.clear();
     ml::clear();
 }
Exemplo n.º 27
0
 void ann::clear()
 {
     grt_ann.clear();
     ml::clear();
     clear_index_maps();
 }
Exemplo n.º 28
0
 /// Write the wrapped GRT ANN's minimum-change threshold into the out-parameter.
 void ann::get_min_change(float &min_change) const
 {
     const auto threshold = grt_ann.getMinChange();
     min_change = threshold;
 }
Exemplo n.º 29
0
 /// Write the wrapped GRT ANN's maximum epoch count into the out-parameter.
 void ann::get_max_epochs(int &max_epochs) const
 {
     const auto epochs = grt_ann.getMaxNumEpochs();
     max_epochs = epochs;
 }
Exemplo n.º 30
0
 void ann::set_max_epochs(int max_epochs)
 {
     grt_ann.setMaxNumEpochs(max_epochs);
 }