// Methods // NOTE: ANN is special since it supports both regression and classification, we therefore override these methods void ann::train() { const data_type data_type = get_data_type(); GRT::UINT numSamples = data_type == LABELLED_CLASSIFICATION ? classification_data.getNumSamples() : regression_data.getNumSamples(); if (numSamples == 0) { flext::error("no observations added, use 'add' to add training data"); return; } bool success = false; if (data_type == LABELLED_CLASSIFICATION) { grt_ann.init( classification_data.getNumDimensions(), num_hidden_neurons, classification_data.getNumClasses(), input_activation_function, hidden_activation_function, output_activation_function ); success = grt_ann.train(classification_data); } else if (data_type == LABELLED_REGRESSION) { grt_ann.init( regression_data.getNumInputDimensions(), num_hidden_neurons, regression_data.getNumTargetDimensions(), input_activation_function, hidden_activation_function, output_activation_function ); success = grt_ann.train(regression_data); } if (!success) { flext::error("training failed"); } t_atom a_success; SetInt(a_success, success); ToOutAnything(1, get_s_train(), 1, &a_success); }