Esempio n. 1
0
    // Methods
    // NOTE: ANN is special because it supports both regression and classification, so we override these methods.
    void ann::train()
    {
        const data_type data_type = get_data_type();
        
        GRT::UINT numSamples = data_type == LABELLED_CLASSIFICATION ? classification_data.getNumSamples() : regression_data.getNumSamples();
        
        if (numSamples == 0)
        {
            flext::error("no observations added, use 'add' to add training data");
            return;
        }
        
        bool success = false;
        
        if (data_type == LABELLED_CLASSIFICATION)
        {
            grt_ann.init(
                     classification_data.getNumDimensions(),
                     num_hidden_neurons,
                     classification_data.getNumClasses(),
                     input_activation_function,
                     hidden_activation_function,
                     output_activation_function
                     );
            success = grt_ann.train(classification_data);
        }
        else if (data_type == LABELLED_REGRESSION)
        {
            grt_ann.init(
                     regression_data.getNumInputDimensions(),
                     num_hidden_neurons,
                     regression_data.getNumTargetDimensions(),
                     input_activation_function,
                     hidden_activation_function,
                     output_activation_function
                     );
            success = grt_ann.train(regression_data);
        }

        if (!success)
        {
            flext::error("training failed");
        }
        
        t_atom a_success;
        
        SetInt(a_success, success);
        ToOutAnything(1, get_s_train(), 1, &a_success);
    }