/**
 * Trains the model by computing, for every class, the per-feature mean of all
 * samples of that class in stream `stream_index`.
 *
 * @param samples       sample list to learn from; iterated via reset()/next()
 * @param stream_index  index of the stream whose values are averaged; the
 *                      stream data is read as ssi_real_t
 * @return true on success; false (with a warning) if the sample list is empty,
 *         the stream index is out of range, or the model is already trained
 */
bool MyModel::train (ISamples &samples, ssi_size_t stream_index) {

	if (samples.getSize () == 0) {
		ssi_wrn ("empty sample list");
		return false;
	}

	if (stream_index >= samples.getStreamSize ()) {
		ssi_wrn ("stream index '%u' out of range", stream_index);
		return false;
	}

	if (isTrained ()) {
		ssi_wrn ("already trained");
		return false;
	}

	_n_classes = samples.getClassSize ();
	_n_features = samples.getStream (stream_index).dim;

	// one zero-initialised row of feature sums per class
	_centers = new ssi_real_t *[_n_classes];
	for (ssi_size_t i = 0; i < _n_classes; i++) {
		_centers[i] = new ssi_real_t[_n_features];
		for (ssi_size_t j = 0; j < _n_features; j++) {
			_centers[i][j] = 0;
		}
	}

	// accumulate the feature sums of each class
	ssi_sample_t *sample = 0;
	samples.reset ();
	while ((sample = samples.next ()) != 0) {
		ssi_size_t id = sample->class_id;
		if (id >= _n_classes) {
			// an out-of-range class id would write past the end of _centers
			continue;
		}
		const ssi_real_t *ptr = ssi_pcast (ssi_real_t, sample->streams[stream_index]->ptr);
		for (ssi_size_t j = 0; j < _n_features; j++) {
			_centers[id][j] += ptr[j];
		}
	}

	// turn the sums into means
	for (ssi_size_t i = 0; i < _n_classes; i++) {
		ssi_size_t num = samples.getSize (i);
		if (num == 0) {
			// dividing by zero would turn the center into NaN/inf; keep it at zero
			ssi_wrn ("class '%u' has no samples", i);
			continue;
		}
		for (ssi_size_t j = 0; j < _n_features; j++) {
			_centers[i][j] /= num;
		}
	}

	return true;
}
/**
 * Trains the k-nearest-neighbour classifier by copying every training sample
 * of stream `stream_index` into a contiguous buffer (`_data`, one row of
 * `_n_features` values per sample) and storing its class id in `_classes`.
 *
 * @param samples       sample list to learn from; iterated via reset()/next()
 * @param stream_index  index of the stream to copy; its data is read as
 *                      ssi_real_t
 * @return true on success; false (with a warning) if the sample list is empty,
 *         has fewer than k entries, the stream index is out of range, or the
 *         model is already trained
 */
bool SimpleKNN::train (ISamples &samples, ssi_size_t stream_index) {

	if (samples.getSize () == 0) {
		ssi_wrn ("empty sample list");
		return false;
	}

	if (samples.getSize () < _options.k) {
		ssi_wrn ("sample list has less than '%u' entries", _options.k);
		return false;
	}

	if (stream_index >= samples.getStreamSize ()) {
		ssi_wrn ("stream index '%u' out of range", stream_index);
		return false;
	}

	if (isTrained ()) {
		ssi_wrn ("already trained");
		return false;
	}

	_n_classes = samples.getClassSize ();
	_n_samples = samples.getSize ();
	_n_features = samples.getStream (stream_index).dim;
	_data = new ssi_real_t[_n_features * _n_samples];
	_classes = new ssi_size_t[_n_samples];

	// copy feature vectors row by row and remember the class label of each row
	ssi_real_t *data_ptr = _data;
	ssi_size_t *class_ptr = _classes;
	const ssi_size_t bytes_to_copy = _n_features * sizeof (ssi_real_t);

	ssi_sample_t *sample = 0;
	samples.reset ();
	while ((sample = samples.next ()) != 0) {
		memcpy (data_ptr, sample->streams[stream_index]->ptr, bytes_to_copy);
		*class_ptr++ = sample->class_id;
		data_ptr += _n_features;
	}

	return true;
}
/**
 * Ranks every dimension of stream `stream_index` individually. For each
 * dimension, the configured model is evaluated on a view of the samples that
 * contains only that dimension. Evaluation uses leave-one-out, leave-one-user-out
 * or k-fold, depending on the options. The class-wise probability is stored as
 * the dimension's score.
 *
 * @param samples       sample list to evaluate on
 * @param stream_index  index of the stream whose dimensions are ranked
 * @return true on success; false (with a warning) if no model has been set,
 *         the sample list is empty, or the stream index is out of range.
 *         Previously computed scores are kept when the call is rejected.
 */
bool Rank::train (ISamples &samples, ssi_size_t stream_index) {

	if (!_model) {
		ssi_wrn ("a model has not been set yet");
		return false;
	}

	if (samples.getSize () == 0) {
		ssi_wrn ("empty sample list");
		return false;
	}

	if (stream_index >= samples.getStreamSize ()) {
		ssi_wrn ("stream index '%u' out of range", stream_index);
		return false;
	}

	release ();

	_n_scores = samples.getStream (stream_index).dim;
	_scores = new score[_n_scores];

	Evaluation eval;
	Trainer trainer (_model, stream_index);

	SSI_DBG (SSI_LOG_LEVEL_DEBUG, "evaluate dimensions:");
	for (ssi_size_t ndim = 0; ndim < _n_scores; ndim++) {

		// view of the samples restricted to the single dimension ndim
		ISSelectDim samples_s (&samples);
		samples_s.setSelection (stream_index, 1, &ndim);

		if (_options.loo) {
			eval.evalLOO (&trainer, samples_s);
		} else if (_options.louo) {
			eval.evalLOUO (&trainer, samples_s);
		} else {
			eval.evalKFold (&trainer, samples_s, _options.kfold);
		}

		_scores[ndim].index = ndim;
		_scores[ndim].value = eval.get_classwise_prob ();
		SSI_DBG (SSI_LOG_LEVEL_DEBUG, " #%02u -> %.2f", _scores[ndim].index, _scores[ndim].value);
	}

	trainer.release ();

	return true;
}
bool FileSamplesOut::open (ISamples &data, const ssi_char_t *path, File::TYPE type, File::VERSION version) { ssi_msg (SSI_LOG_LEVEL_DETAIL, "open files '%s'", path); _version = version; if (_version < File::V2) { ssi_wrn ("version < V2 not supported"); return false; } if (_file_info || _file_data) { ssi_wrn ("samples already open"); return false; } _n_users = data.getUserSize (); _users = new ssi_char_t *[_n_users]; _n_per_user = new ssi_size_t[_n_users]; for (ssi_size_t i = 0; i < _n_users; i++) { _users[i] = ssi_strcpy (data.getUserName (i)); _n_per_user[i] = 0; } _n_classes = data.getClassSize (); _classes = new ssi_char_t *[_n_classes]; _n_per_class = new ssi_size_t[_n_classes]; for (ssi_size_t i = 0; i < _n_classes; i++) { _classes[i] = ssi_strcpy (data.getClassName (i)); _n_per_class[i] = 0; } _n_streams = data.getStreamSize (); _streams = new ssi_stream_t[_n_streams]; for (ssi_size_t i = 0; i < _n_streams; i++) { ssi_stream_t s = data.getStream (i); ssi_stream_init (_streams[i], 0, s.dim, s.byte, s.type, s.sr, 0); } _has_missing_data = false; if (path == 0 || path[0] == '\0') { _console = true; } if (_console) { _file_data = File::CreateAndOpen (type, File::WRITE, ""); if (!_file_data) { ssi_wrn ("could not open console"); return false; } } else { FilePath fp (path); ssi_char_t *path_info = 0; if (strcmp (fp.getExtension (), SSI_FILE_TYPE_SAMPLES) != 0) { path_info = ssi_strcat (path, SSI_FILE_TYPE_SAMPLES); } else { path_info = ssi_strcpy (path); } _path = ssi_strcpy (path_info); _file_info = File::CreateAndOpen (File::ASCII, File::WRITE, path_info); if (!_file_info) { ssi_wrn ("could not open info file '%s'", path_info); return false; } ssi_sprint (_string, "<?xml version=\"1.0\" ?>\n<samples ssi-v=\"%d\">", version); _file_info->writeLine (_string); ssi_char_t *path_data = ssi_strcat (path_info, "~"); _file_data = File::CreateAndOpen (type, File::WRITE, path_data); if (!_file_data) { ssi_wrn ("could not open data file '%s'", path_data); return false; } if 
(_version == File::V3) { _file_streams = new FileStreamOut[_n_streams]; ssi_char_t string[SSI_MAX_CHAR]; for (ssi_size_t i = 0; i < _n_streams; i++) { ssi_sprint (string, "%s.#%u", path_info, i); _file_streams[i].open (_streams[i], string, type); } } delete[] path_info; delete[] path_data; } return true; };