// Trains the member models of a majority-voting ensemble.
// Model i is bound to stream i of the sample list; every model that is not yet
// trained is trained on its own stream, already trained models are left as is.
// If the samples contain missing data, they are wrapped in an ISMissingData
// view that is switched to stream i before model i is trained (presumably so
// that samples missing in that stream are skipped — see ISMissingData).
// Returns false (after a warning) if the sample list is empty, #models differs
// from #streams, or the fusion has already been trained; true otherwise.
// Note: the return values of IModel::train are not checked.
bool MajorityVoting::train (ssi_size_t n_models, IModel **models, ISamples &samples) {

	// --- reject invalid or repeated requests ---
	if (samples.getSize () == 0) {
		ssi_wrn ("empty sample list");
		return false;
	}
	if (samples.getStreamSize () != n_models) {
		ssi_wrn ("#models (%u) differs from #streams (%u)", n_models, samples.getStreamSize ());
		return false;
	}
	if (isTrained ()) {
		ssi_wrn ("already trained");
		return false;
	}

	// --- remember the ensemble layout ---
	_n_streams = samples.getStreamSize ();
	_n_classes = samples.getClassSize ();
	_n_models = n_models;

	// --- complete data: train directly on the original samples ---
	if (!samples.hasMissingData ()) {
		for (ssi_size_t idx = 0; idx < n_models; idx++) {
			if (models[idx]->isTrained ()) {
				continue;
			}
			models[idx]->train (samples, idx);
		}
		return true;
	}

	// --- missing data: train through a per-stream filtered view ---
	ISMissingData filtered (&samples);
	for (ssi_size_t idx = 0; idx < n_models; idx++) {
		if (models[idx]->isTrained ()) {
			continue;
		}
		filtered.setStream (idx);
		models[idx]->train (filtered, idx);
	}
	return true;
}
// Trains the member models of a simple fusion.
// Two layouts are accepted: a single stream shared by all models (every model
// is trained on stream 0), or one stream per model (model i <-> stream i).
// Models that are already trained are left untouched.
// If the samples contain missing data, training goes through an ISMissingData
// view that is switched to the model's stream before training.
// Returns false on an empty sample list, a repeated call, or a #models/#streams
// mismatch (reported via ssi_err); otherwise marks the fusion trained and
// returns true.
bool SimpleFusion::train (ssi_size_t n_models, IModel **models, ISamples &samples) {

	if (samples.getSize () == 0) {
		ssi_wrn ("empty sample list");
		return false;
	}

	if (isTrained ()) {
		ssi_wrn ("already trained");
		return false;
	}

	ssi_size_t n_streams = samples.getStreamSize ();
	if (n_streams != 1 && n_streams != n_models) {
		ssi_err ("#models (%u) differs from #streams (%u)", n_models, n_streams);
		// Never continue with an invalid model->stream mapping, even if
		// ssi_err returns to the caller.
		return false;
	}

	if (samples.hasMissingData ()) {
		ISMissingData samples_h (&samples);
		for (ssi_size_t n_model = 0; n_model < n_models; n_model++) {
			if (!models[n_model]->isTrained ()) {
				ssi_size_t stream_index = n_streams == 1 ? 0 : n_model;
				samples_h.setStream (stream_index);
				// Train on the same stream the filter was set to; passing n_model
				// here would address a non-existent stream when n_streams == 1.
				models[n_model]->train (samples_h, stream_index);
			}
		}
	} else {
		for (ssi_size_t n_model = 0; n_model < n_models; n_model++) {
			if (!models[n_model]->isTrained ()) {
				ssi_size_t stream_index = n_streams == 1 ? 0 : n_model;
				models[n_model]->train (samples, stream_index);
			}
		}
	}

	_is_trained = true;
	return true;
}
bool FeatureFusion::train (ssi_size_t n_models, IModel **models, ISamples &samples) { if (samples.getSize () == 0) { ssi_wrn ("empty sample list"); return false; } if (isTrained ()) { ssi_wrn ("already trained"); return false; } _n_streams = samples.getStreamSize (); _n_classes = samples.getClassSize (); _n_models = n_models; //initialize weights ssi_real_t **weights = new ssi_real_t*[n_models]; for (ssi_size_t n_model = 0; n_model < n_models; n_model++) { weights[n_model] = new ssi_real_t[_n_classes+1]; } if (samples.hasMissingData ()) { _handle_md = true; ISMissingData samples_h (&samples); Evaluation eval; if (ssi_log_level >= SSI_LOG_LEVEL_DEBUG) { ssi_print("\nMissing data detected.\n"); } //models[0] is featfuse_model, followed by singlechannel_models ISMergeDim ffusionSamples (&samples); ISMissingData ffusionSamples_h (&ffusionSamples); ffusionSamples_h.setStream(0); if (!models[0]->isTrained ()) { models[0]->train (ffusionSamples_h, 0); } if (ssi_log_level >= SSI_LOG_LEVEL_DEBUG) { eval.eval (*models[0], ffusionSamples_h, 0); eval.print(); } //dummy weights for fused model for (ssi_size_t n_class = 0; n_class < _n_classes; n_class++) { weights[0][n_class] = 0.0f; } weights[0][_n_classes] = 0.0f; for (ssi_size_t n_model = 1; n_model < n_models; n_model++) { if (!models[n_model]->isTrained ()) { samples_h.setStream (n_model - 1); models[n_model]->train (samples_h, n_model - 1); } eval.eval (*models[n_model], samples_h, n_model - 1); if (ssi_log_level >= SSI_LOG_LEVEL_DEBUG) { eval.print(); } for (ssi_size_t n_class = 0; n_class < _n_classes; n_class++) { weights[n_model][n_class] = eval.get_class_prob (n_class); } weights[n_model][_n_classes] = eval.get_classwise_prob (); } //calculate fillers _filler = new ssi_size_t[_n_streams]; for (ssi_size_t n_fill = 0; n_fill < _n_streams; n_fill++) { _filler[n_fill] = 1; ssi_real_t filler_weight = weights[1][_n_classes]; for (ssi_size_t n_model = 2; n_model < n_models; n_model++) { if (filler_weight < 
weights[n_model][_n_classes]) { _filler[n_fill] = n_model; filler_weight = weights[n_model][_n_classes]; } } weights[_filler[n_fill]][_n_classes] = 0.0f; } if (ssi_log_level >= SSI_LOG_LEVEL_DEBUG) { ssi_print("\nfiller:\n"); for (ssi_size_t n_model = 0; n_model < _n_streams; n_model++) { ssi_print("%d ", _filler[n_model]); }ssi_print("\n"); } } else{ _handle_md = false; if (ssi_log_level >= SSI_LOG_LEVEL_DEBUG) { ssi_print("\nNo missing data detected.\n"); } ISMergeDim ffusionSamples (&samples); if (!models[0]->isTrained ()) { models[0]->train (ffusionSamples, 0); } //dummy _filler = new ssi_size_t[_n_streams]; for (ssi_size_t n_fill = 0; n_fill < _n_streams; n_fill++) { _filler[n_fill] = 0; } } if (weights) { for (ssi_size_t n_model = 0; n_model < _n_models; n_model++) { delete[] weights[n_model]; } delete[] weights; weights = 0; } return true; }
bool WeightedMajorityVoting::train (ssi_size_t n_models, IModel **models, ISamples &samples) { if (samples.getSize () == 0) { ssi_wrn ("empty sample list"); return false; } if (samples.getStreamSize () != n_models) { ssi_wrn ("#models (%u) differs from #streams (%u)", n_models, samples.getStreamSize ()); return false; } if (isTrained ()) { ssi_wrn ("already trained"); return false; } _n_streams = samples.getStreamSize (); _n_classes = samples.getClassSize (); _n_models = n_models; _weights = new ssi_real_t*[n_models]; for (ssi_size_t n_model = 0; n_model < n_models; n_model++) { _weights[n_model] = new ssi_real_t[_n_classes+1]; } if (samples.hasMissingData ()) { ISMissingData samples_h (&samples); Evaluation eval; for (ssi_size_t n_model = 0; n_model < n_models; n_model++) { if (!models[n_model]->isTrained ()) { samples_h.setStream (n_model); models[n_model]->train (samples_h, n_model); } eval.eval (*models[n_model], samples_h, n_model); for (ssi_size_t n_class = 0; n_class < _n_classes; n_class++) { _weights[n_model][n_class] = eval.get_class_prob (n_class); } _weights[n_model][_n_classes] = eval.get_classwise_prob (); } } else{ Evaluation eval; for (ssi_size_t n_model = 0; n_model < n_models; n_model++) { if (!models[n_model]->isTrained ()) { models[n_model]->train (samples, n_model); } eval.eval (*models[n_model], samples, n_model); for (ssi_size_t n_class = 0; n_class < _n_classes; n_class++) { _weights[n_model][n_class] = eval.get_class_prob (n_class); } _weights[n_model][_n_classes] = eval.get_classwise_prob (); } } if (ssi_log_level >= SSI_LOG_LEVEL_DEBUG) { ssi_print("\nClassifier Weights: \n"); for (ssi_size_t n_model = 0; n_model < n_models; n_model++) { for (ssi_size_t n_class = 0; n_class < _n_classes; n_class++) { ssi_print ("%f ", _weights[n_model][n_class]); } ssi_print ("%f\n", _weights[n_model][_n_classes]); } } return true; }