bool train( const Ptr<TrainData>& trainData, int flags ) { const int MAX_ITER = 1000; const double DEFAULT_EPSILON = FLT_EPSILON; // initialize training data Mat inputs = trainData->getTrainSamples(); Mat outputs = trainData->getTrainResponses(); Mat sw = trainData->getTrainSampleWeights(); prepare_to_train( inputs, outputs, sw, flags ); // ... and link weights if( !(flags & UPDATE_WEIGHTS) ) init_weights(); TermCriteria termcrit; termcrit.type = TermCriteria::COUNT + TermCriteria::EPS; termcrit.maxCount = std::max((params.termCrit.type & CV_TERMCRIT_ITER ? params.termCrit.maxCount : MAX_ITER), 1); termcrit.epsilon = std::max((params.termCrit.type & CV_TERMCRIT_EPS ? params.termCrit.epsilon : DEFAULT_EPSILON), DBL_EPSILON); int iter = params.trainMethod == ANN_MLP::BACKPROP ? train_backprop( inputs, outputs, sw, termcrit ) : train_rprop( inputs, outputs, sw, termcrit ); trained = iter > 0; return trained; }
int CvANN_MLP::train( const CvMat* _inputs, const CvMat* _outputs, const CvMat* _sample_weights, const CvMat* _sample_idx, CvANN_MLP_TrainParams _params, int flags ) { const int MAX_ITER = 1000; const double DEFAULT_EPSILON = FLT_EPSILON; double* sw = 0; CvVectors x0, u; int iter = -1; x0.data.ptr = u.data.ptr = 0; CV_FUNCNAME( "CvANN_MLP::train" ); __BEGIN__; int max_iter; double epsilon; params = _params; // initialize training data CV_CALL( prepare_to_train( _inputs, _outputs, _sample_weights, _sample_idx, &x0, &u, &sw, flags )); // ... and link weights if( !(flags & UPDATE_WEIGHTS) ) init_weights(); max_iter = params.term_crit.type & CV_TERMCRIT_ITER ? params.term_crit.max_iter : MAX_ITER; max_iter = MIN( max_iter, MAX_ITER ); max_iter = MAX( max_iter, 1 ); epsilon = params.term_crit.type & CV_TERMCRIT_EPS ? params.term_crit.epsilon : DEFAULT_EPSILON; epsilon = MAX(epsilon, DBL_EPSILON); params.term_crit.type = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS; params.term_crit.max_iter = max_iter; params.term_crit.epsilon = epsilon; if( params.train_method == CvANN_MLP_TrainParams::BACKPROP ) { CV_CALL( iter = train_backprop( x0, u, sw )); } else { CV_CALL( iter = train_rprop( x0, u, sw )); } __END__; cvFree( &x0.data.ptr ); cvFree( &u.data.ptr ); cvFree( &sw ); return iter; }