コード例 #1
0
ファイル: MyModel.cpp プロジェクト: hcmlab/mobileSSI
// Trains the model by computing one centre per class: the arithmetic mean of
// the feature vectors of all samples carrying that class id.
//
// samples       sample list to learn from; it is reset and iterated once
// stream_index  index of the stream whose values are used as features
//               (the stream data is read as ssi_real_t)
//
// Returns true on success. Returns false (after a warning) if the sample
// list is empty or the model has already been trained.
bool MyModel::train (ISamples &samples,
	ssi_size_t stream_index) {

	if (samples.getSize () == 0) {
		ssi_wrn ("empty sample list");
		return false;
	}

	if (isTrained ()) {
		ssi_wrn ("already trained");
		return false;
	}

	_n_classes = samples.getClassSize ();
	_n_features = samples.getStream (stream_index).dim;

	// One zero-initialised accumulator row per class.
	_centers = new ssi_real_t *[_n_classes];
	for (ssi_size_t i = 0; i < _n_classes; i++) {
		_centers[i] = new ssi_real_t[_n_features];
		for (ssi_size_t j = 0; j < _n_features; j++) {
			_centers[i][j] = 0;
		}
	}

	// Sum up the feature vectors per class.
	samples.reset ();
	ssi_sample_t *sample = 0;
	ssi_size_t n_skipped = 0;
	while ((sample = samples.next ()) != 0) {
		ssi_size_t id = sample->class_id;
		if (id >= _n_classes) {
			// An id outside [0, _n_classes) would write past the end of _centers.
			n_skipped++;
			continue;
		}
		const ssi_real_t *ptr = ssi_pcast (ssi_real_t, sample->streams[stream_index]->ptr);
		for (ssi_size_t j = 0; j < _n_features; j++) {
			_centers[id][j] += ptr[j];
		}
	}
	if (n_skipped > 0) {
		ssi_wrn ("skipped '%u' samples with an invalid class id", n_skipped);
	}

	// Turn the sums into means.
	for (ssi_size_t i = 0; i < _n_classes; i++) {
		ssi_size_t num = samples.getSize (i);
		if (num == 0) {
			// Dividing the all-zero sum by zero would store NaN in the centre;
			// leave the centre at the origin instead.
			ssi_wrn ("class '%u' has no samples", i);
			continue;
		}
		for (ssi_size_t j = 0; j < _n_features; j++) {
			_centers[i][j] /= num;
		}
	}

	return true;
}
コード例 #2
0
ファイル: SimpleKNN.cpp プロジェクト: hihiy/IntelliVoice
// Trains the k-nearest-neighbour model by storing a copy of every sample's
// feature vector together with its class id.
//
// samples       sample list to memorise; it is reset and iterated once
// stream_index  index of the stream whose values are copied
//
// Returns true on success. Returns false (after a warning) if the sample
// list is empty, holds fewer than _options.k samples, or the model has
// already been trained.
bool SimpleKNN::train (ISamples &samples,
	ssi_size_t stream_index) {

	if (samples.getSize () == 0) {
		ssi_wrn ("empty sample list");
		return false;
	}

	if (samples.getSize () < _options.k) {
		ssi_wrn ("sample list has less than '%u' entries", _options.k);
		return false;
	}

	if (isTrained ()) {
		ssi_wrn ("already trained");
		return false;
	}

	_n_classes = samples.getClassSize ();
	_n_samples = samples.getSize ();
	_n_features = samples.getStream (stream_index).dim;

	// Feature vectors are stored back to back, one row of _n_features values per sample.
	_data = new ssi_real_t[_n_features*_n_samples];
	_classes = new ssi_size_t[_n_samples];

	samples.reset ();
	ssi_real_t *data_ptr = _data;
	ssi_size_t *class_ptr = _classes;
	const ssi_size_t bytes_to_copy = _n_features * sizeof (ssi_real_t);
	ssi_size_t n_copied = 0;
	ssi_sample_t *sample = 0;
	while ((sample = samples.next ()) != 0) {
		if (n_copied == _n_samples) {
			// The buffers were sized from getSize (); never write past them.
			ssi_wrn ("sample list delivered more than '%u' samples, ignoring the rest", _n_samples);
			break;
		}
		memcpy (data_ptr, sample->streams[stream_index]->ptr, bytes_to_copy);
		*class_ptr++ = sample->class_id;
		data_ptr += _n_features;
		n_copied++;
	}

	if (n_copied < _n_samples) {
		// Rows beyond n_copied were never written; do not count them as training data.
		ssi_wrn ("sample list delivered only '%u' samples", n_copied);
		_n_samples = n_copied;
	}

	return true;
}
コード例 #3
0
ファイル: Rank.cpp プロジェクト: hihiy/IntelliVoice
// Ranks the dimensions of one stream: each dimension is evaluated on its own
// with the configured model and the resulting class-wise probability is
// stored as that dimension's score.
//
// samples       sample list used for the evaluation
// stream_index  index of the stream whose dimensions are scored
//
// Returns true on success, false (after a warning) if no model has been set.
bool Rank::train (ISamples &samples,
	ssi_size_t stream_index) {

	if (!_model) {
		ssi_wrn ("a model has not been set yet");
		return false;
	}

	// Start from a clean state before allocating the new score table.
	release ();

	_n_scores = samples.getStream (stream_index).dim;
	_scores = new score[_n_scores];

	Evaluation eval;
	Trainer trainer (_model, stream_index);

	SSI_DBG (SSI_LOG_LEVEL_DEBUG, "evaluate dimensions:");
	for (ssi_size_t dim = 0; dim < _n_scores; ++dim) {

		// View on the samples that selects only the current dimension.
		ISSelectDim single_dim (&samples);
		single_dim.setSelection (stream_index, 1, &dim);

		// Evaluation scheme: leave-one-out takes precedence over
		// leave-one-user-out; k-fold is the fallback.
		if (!_options.loo && !_options.louo) {
			eval.evalKFold (&trainer, single_dim, _options.kfold);
		} else if (_options.loo) {
			eval.evalLOO (&trainer, single_dim);
		} else {
			eval.evalLOUO (&trainer, single_dim);
		}

		score &entry = _scores[dim];
		entry.index = dim;
		entry.value = eval.get_classwise_prob ();
		SSI_DBG (SSI_LOG_LEVEL_DEBUG, "  #%02u -> %.2f", entry.index, entry.value);
	}

	trainer.release ();

	return true;
}
コード例 #4
0
ファイル: FileSamplesOut.cpp プロジェクト: hihiy/IntelliVoice
// Prepares the writer for a new sample list and opens the output files.
//
// User names, class names and the stream layout are copied from `data`, and
// the per-user / per-class counters are reset to zero. If `path` is null or
// empty the output goes to the console; otherwise the following files are
// created:
//   <path>[SSI_FILE_TYPE_SAMPLES]   info file (ASCII), extension appended if missing
//   <info file>~                    data file, written with `type`
//   <info file>.#<i>                one file per stream, only for File::V3
//
// data     sample list whose meta information is taken over
// path     output path, or null / "" for console output
// type     file type used for the data file and the stream files
// version  file format version; versions below File::V2 are rejected
//
// Returns true on success. Returns false (after a warning) if the version is
// not supported, files are already open, or a file could not be opened.
// NOTE(review): on a failed open the member arrays allocated here (and
// possibly `_file_info`) stay set; presumably close () releases them - confirm.
bool FileSamplesOut::open (ISamples &data,
	const ssi_char_t *path,
	File::TYPE type,
	File::VERSION version) {

	// `path` may legitimately be null (console mode); never pass null to "%s".
	ssi_msg (SSI_LOG_LEVEL_DETAIL, "open files '%s'", path ? path : "");

	_version = version;

	if (_version < File::V2) {
		ssi_wrn ("version < V2 not supported");
		return false;
	}

	if (_file_info || _file_data) {
		ssi_wrn ("samples already open");
		return false;
	}

	// Take over user names and reset the per-user sample counters.
	_n_users = data.getUserSize ();
	_users = new ssi_char_t *[_n_users];
	_n_per_user = new ssi_size_t[_n_users];
	for (ssi_size_t i = 0; i < _n_users; i++) {
		_users[i] = ssi_strcpy (data.getUserName (i));
		_n_per_user[i] = 0;
	}

	// Take over class names and reset the per-class sample counters.
	_n_classes = data.getClassSize ();
	_classes = new ssi_char_t *[_n_classes];
	_n_per_class = new ssi_size_t[_n_classes];
	for (ssi_size_t i = 0; i < _n_classes; i++) {
		_classes[i] = ssi_strcpy (data.getClassName (i));
		_n_per_class[i] = 0;
	}

	// Remember the layout (dim, byte, type, sr) of every stream.
	_n_streams = data.getStreamSize ();
	_streams = new ssi_stream_t[_n_streams];
	for (ssi_size_t i = 0; i < _n_streams; i++) {
		ssi_stream_t s = data.getStream (i);
		ssi_stream_init (_streams[i], 0, s.dim, s.byte, s.type, s.sr, 0);
	}

	_has_missing_data = false;

	// NOTE(review): `_console` is only ever switched on here, never off;
	// a non-empty path keeps whatever value it had before - confirm intended.
	if (path == 0 || path[0] == '\0') {
		_console = true;
	}

	if (_console) {

		_file_data = File::CreateAndOpen (type, File::WRITE, "");
		if (!_file_data) {
			ssi_wrn ("could not open console");
			return false;
		}

	} else {

		// Make sure the info file carries the samples extension.
		FilePath fp (path);
		ssi_char_t *path_info = 0;
		if (strcmp (fp.getExtension (), SSI_FILE_TYPE_SAMPLES) != 0) {
			path_info = ssi_strcat (path, SSI_FILE_TYPE_SAMPLES);
		} else {
			path_info = ssi_strcpy (path);
		}
		_path = ssi_strcpy (path_info);

		_file_info = File::CreateAndOpen (File::ASCII, File::WRITE, path_info);
		if (!_file_info) {
			ssi_wrn ("could not open info file '%s'", path_info);
			// The temporary path string is owned by this function.
			delete[] path_info;
			return false;
		}

		ssi_sprint (_string, "<?xml version=\"1.0\" ?>\n<samples ssi-v=\"%d\">", version);
		_file_info->writeLine (_string);

		ssi_char_t *path_data = ssi_strcat (path_info, "~");
		_file_data = File::CreateAndOpen (type, File::WRITE, path_data);
		if (!_file_data) {
			ssi_wrn ("could not open data file '%s'", path_data);
			// Release both temporary path strings before bailing out.
			delete[] path_info;
			delete[] path_data;
			return false;
		}

		if (_version == File::V3) {

			// V3 stores every stream in a file of its own.
			_file_streams = new FileStreamOut[_n_streams];
			ssi_char_t string[SSI_MAX_CHAR];
			for (ssi_size_t i = 0; i < _n_streams; i++) {
				ssi_sprint (string, "%s.#%u", path_info, i);
				_file_streams[i].open (_streams[i], string, type);
			}

		}

		delete[] path_info;
		delete[] path_data;

	}

	return true;
}