示例#1
0
/** Estimate the Gaussian's parameters from @p data.
 *
 * The features are stored via set_features(); their mean vector and
 * covariance matrix then replace m_mean / m_cov, and init() is called
 * (presumably to refresh state derived from them).
 *
 * @param data features to train on; must be non-NULL and have the
 *             FP_DOT property (i.e. be usable as CDotFeatures)
 * @return true on success, false if @p data was rejected
 */
bool CGaussian::train(CFeatures* data)
{
	// Without this check a NULL argument would be dereferenced below
	// (dotdata->get_mean()).
	// NOTE(review): "init features with data if necessary" suggests NULL
	// was meant to reuse previously set features; that needs the stored
	// feature member, which is not visible here.
	if (!data)
	{
		SG_ERROR("No features given for training\n");
		return false;
	}

	// assure the feature type is correct before storing the features
	if (!data->has_property(FP_DOT))
	{
		SG_ERROR("Specified features are not of type CDotFeatures\n");
		return false;
	}
	set_features(data);

	CDotFeatures* dotdata = (CDotFeatures *) data;

	// get_mean()/get_cov() hand back arrays that this object then owns;
	// the ones from the previous call are released (delete[]) first.
	delete[] m_mean;
	delete[] m_cov;

	dotdata->get_mean(&m_mean, &m_mean_length);
	dotdata->get_cov(&m_cov, &m_cov_rows, &m_cov_cols);

	init();

	return true;
}
示例#2
0
文件: GMM.cpp 项目: AsherBond/shogun
/** Fit the Gaussian mixture with expectation-maximisation (EM).
 *
 * The m_n components are initialised with k-means cluster centres as
 * means, the covariance of the whole data set as covariance, and uniform
 * mixing weights. EM then alternates the E-step (responsibilities) and
 * the M-step (weights, means, covariances) until the absolute change of
 * the expected complete-data log-likelihood is at most m_minimal_change,
 * or m_max_iter iterations have run.
 *
 * @param data training features; must be non-NULL and have the FP_DOT
 *             property (i.e. be usable as CDotFeatures)
 * @return true on success, false if @p data was rejected
 */
bool CGMM::train(CFeatures* data)
{
	ASSERT(m_n != 0);

	// Validate before cleanup() so a rejected call keeps the old model.
	if (!data)
	{
		SG_ERROR("No features given for training\n");
		return false;
	}
	if (!data->has_property(FP_DOT))
	{
		SG_ERROR("Specified features are not of type CDotFeatures\n");
		return false;
	}

	// drop a previously trained model
	if (m_components)
		cleanup();

	set_features(data);

	CDotFeatures* dotdata = (CDotFeatures *) data;
	int32_t num_vectors = dotdata->get_num_vectors();
	int32_t num_dim = dotdata->get_dim_feature_space();

	/** Initialisation: k-means centres as means, global covariance */
	CEuclidianDistance* dist = new CEuclidianDistance();
	CKMeans* init_k_means = new CKMeans(m_n, dist);
	init_k_means->train(dotdata);
	float64_t* init_means;
	int32_t init_mean_dim;
	int32_t init_mean_size;
	init_k_means->get_cluster_centers(&init_means, &init_mean_dim, &init_mean_size);

	float64_t* init_cov;
	int32_t init_cov_rows;
	int32_t init_cov_cols;
	dotdata->get_cov(&init_cov, &init_cov_rows, &init_cov_cols);

	// One weight per component: the loop below writes m_n entries, so the
	// array must have exactly m_n of them.
	m_coef_size = m_n;
	m_coefficients = new float64_t[m_coef_size];
	m_components = new CGaussian*[m_n];

	for (int32_t i=0; i<m_n; i++)
	{
		m_coefficients[i] = 1.0/m_n;
		m_components[i] = new CGaussian(&(init_means[i*init_mean_dim]), init_mean_dim,
								init_cov, init_cov_rows, init_cov_cols);
	}

	// The same init_cov (and interior pointers into init_means) is handed
	// to every component, so the components must keep their own copies;
	// release the temporaries now.
	delete[] init_cov;
	// NOTE(review): assumes the k-means destructor releases `dist`; if
	// these objects are reference-counted, use the framework's unref.
	delete init_k_means;
	// NOTE(review): init_means comes from get_cluster_centers(); its owner
	// and allocator are not visible here, so it is not released — confirm.

	/** Work buffers, row-major with one row per vector:
	 *  pdfs[i*m_n+j] = density of vector i under component j,
	 *  T[i*m_n+j]    = responsibility of component j for vector i.
	 *  mean_sum / cov_sum are M-step scratch, allocated once. */
	float64_t* pdfs = new float64_t[num_vectors*m_n];
	float64_t* T = new float64_t[num_vectors*m_n];
	float64_t* mean_sum = new float64_t[num_dim];
	float64_t* cov_sum = new float64_t[num_dim*num_dim];

	int32_t iter = 0;
	float64_t e_log_likelihood_change = m_minimal_change + 1;
	float64_t e_log_likelihood_old = 0;
	float64_t e_log_likelihood_new = -FLT_MAX;

	float64_t* point;
	int32_t point_len;

	while (iter<m_max_iter && e_log_likelihood_change>m_minimal_change)
	{
		e_log_likelihood_old = e_log_likelihood_new;
		e_log_likelihood_new = 0;

		/** E-step: component densities for every vector */
		for (int32_t i=0; i<num_vectors; i++)
		{
			dotdata->get_feature_vector(&point, &point_len, i);
			for (int32_t j=0; j<m_n; j++)
				pdfs[i*m_n+j] = m_components[j]->compute_PDF(point, point_len);
			delete[] point;
		}

		/** E-step: responsibilities and expected log-likelihood */
		for (int32_t i=0; i<num_vectors; i++)
		{
			float64_t sum = 0;
			for (int32_t j=0; j<m_n; j++)
				sum += m_coefficients[j]*pdfs[i*m_n+j];

			for (int32_t j=0; j<m_n; j++)
			{
				float64_t weighted = m_coefficients[j]*pdfs[i*m_n+j];
				// A vector no component explains (sum == 0, e.g. density
				// underflow) would give 0/0 = NaN; let it contribute nothing.
				T[i*m_n+j] = (sum > 0) ? weighted/sum : 0;
				// 0*log(0) is taken as 0; evaluated literally it is NaN,
				// which would make the loop condition false and stop EM.
				if (weighted > 0)
					e_log_likelihood_new += T[i*m_n+j]*CMath::log(weighted);
			}
		}

		/** NOTE(review): the absolute value also treats a decrease of the
		 *  likelihood as progress; confirm this is intended. */
		e_log_likelihood_change = CMath::abs(e_log_likelihood_new - e_log_likelihood_old);

		/** M-step: re-estimate weight, mean and covariance per component */
		for (int32_t i=0; i<m_n; i++)
		{
			float64_t T_sum = 0;
			memset(mean_sum, 0, num_dim*sizeof(float64_t));

			for (int32_t j=0; j<num_vectors; j++)
			{
				T_sum += T[j*m_n+i];
				dotdata->get_feature_vector(&point, &point_len, j);
				// mean_sum <- T * point + mean_sum
				CMath::add<float64_t>(mean_sum, T[j*m_n+i], point, 1, mean_sum, point_len);
				delete[] point;
			}

			m_coefficients[i] = T_sum/num_vectors;

			// A component without responsibility mass would get 0/0 = NaN
			// parameters; keep its previous mean and covariance instead.
			if (T_sum <= 0)
				continue;

			for (int32_t j=0; j<num_dim; j++)
				mean_sum[j] /= T_sum;

			m_components[i]->set_mean(mean_sum, num_dim);

			memset(cov_sum, 0, num_dim*num_dim*sizeof(float64_t));
			for (int32_t j=0; j<num_vectors; j++)
			{
				dotdata->get_feature_vector(&point, &point_len, j);
				// point <- point - mean
				CMath::add<float64_t>(point, 1, point, -1, mean_sum, point_len);
				// cov_sum += T * point * point^T  (rank-1 update)
				cblas_dger(CblasRowMajor, num_dim, num_dim, T[j*m_n+i], point, 1, point,
					1, (double*) cov_sum, num_dim);
				delete[] point;
			}

			for (int32_t j=0; j<num_dim*num_dim; j++)
				cov_sum[j] /= T_sum;

			m_components[i]->set_cov(cov_sum, num_dim, num_dim);
		}
		iter++;
	}

	delete[] mean_sum;
	delete[] cov_sum;
	delete[] pdfs;
	delete[] T;
	return true;
}