示例#1
0
/*! Given data (an image) and the currently estimated parameters, find the cluster
assignment of every pixel and store it in `assignment`.
Each pixel is assigned the index of the smallest entry returned by
computeClusterDistance for that pixel (the first such index on ties, as per
std::min_element).
@param img - image data. This is assumed to be interleaved three channel data
(c0, c1, c2, c0, c1, c2, ...); trailing values that do not form a complete
three-value pixel are ignored. */
void GMM::getClusterAssignments(const vector<float>& img) {
	// Keep the pixel count unsigned so it compares cleanly with container sizes
	// and the loop index (no narrowing from size_t to int).
	const size_t n = img.size() / 3;
	if (assignment.size() != n)
		assignment = vector<int>(n, 0);

	// Scratch buffers are hoisted out of the loop and reused for every pixel.
	vector<float> pix(3, 0.0f);
	vector<float> d;
	for (size_t i = 0; i < n; ++i) {
		const size_t base = 3 * i;
		pix[0] = img[base];
		pix[1] = img[base + 1];
		pix[2] = img[base + 2];
		d = computeClusterDistance(pix);
		// Index of the closest cluster; explicit cast documents the narrowing.
		assignment[i] = static_cast<int>(min_element(d.begin(), d.end()) - d.begin());
	}
}
/*! Trains the hierarchical (agglomerative) clustering model.
Level 0 holds one cluster per sample. At every following level the two clusters
with the smallest computeClusterDistance are merged into a new cluster, until a
single cluster remains (M-1 merges for M samples).
@param data - training samples, one sample per row.
@return true when clustering completes; false if data has no rows/columns, or if
at some level no pair of clusters has a distance below the maximum Float value. */
bool HierarchicalClustering::train_(MatrixFloat &data){

    trained = false;
    clusters.clear();
    distanceMatrix.clear();

    if( data.getNumRows() == 0 || data.getNumCols() == 0 ){
        return false;
    }

    //Set the rows and columns
    M = data.getNumRows();
    N = data.getNumCols();

    //Build the sample-to-sample squared Euclidean distance matrix; the diagonal is
    //filled with the maximum Float value
    distanceMatrix.resize(M,M);
    for(UINT i=0; i<M; i++){
        for(UINT j=0; j<M; j++){
            if( i == j ) distanceMatrix[i][j] = grt_numeric_limits< Float >::max();
            else distanceMatrix[i][j] = squaredEuclideanDistance(data[i], data[j]);
        }
    }

    //Build the initial clusters, at the start each sample gets its own cluster
    UINT uniqueClusterID = 0;
    Vector< ClusterInfo > clusterData(M);
    for(UINT i=0; i<M; i++){
        clusterData[i].uniqueClusterID = uniqueClusterID++;
        clusterData[i].addSampleToCluster(i);
    }

    trainingLog << "Starting clustering..." << std::endl;

    //Create the first cluster level (level 0), each sample is its own cluster
    UINT level = 0;
    ClusterLevel initialLevel;
    initialLevel.level = level;
    for(UINT i=0; i<M; i++){
        initialLevel.clusters.push_back( clusterData[i] );
    }
    clusters.push_back( initialLevel );

    //Move to level 1 and start merging. Each pass replaces two clusters with one, so
    //checking the remaining cluster count up front gives exactly M-1 passes, and a
    //single-sample dataset (nothing to merge) no longer fails the pair search.
    level++;
    while( clusterData.size() > 1 ){

        //Find the closest two clusters (row-major scan; the first strict minimum wins)
        Float minDist = grt_numeric_limits< Float >::max();
        UINT indexA = 0;
        UINT indexB = 0;
        const UINT K = (UINT)clusterData.size();
        for(UINT i=0; i<K; i++){
            for(UINT j=0; j<K; j++){
                if( i == j ) continue;
                const Float dist = computeClusterDistance( clusterData[i], clusterData[j] );
                if( dist < minDist ){
                    minDist = dist;
                    indexA = i;
                    indexB = j;
                }
            }
        }

        //No candidate pair had a distance below the sentinel (e.g. all max or NaN)
        if( minDist == grt_numeric_limits< Float >::max() ){
            warningLog << "train_(MatrixFloat &data) - Failed to find any cluster at level: " << level << std::endl;
            return false;
        }

        //Create the merged cluster: samples of the first cluster, then the second
        ClusterInfo newCluster;
        newCluster.uniqueClusterID = uniqueClusterID++;

        const UINT numSamplesInClusterA = clusterData[ indexA ].getNumSamplesInCluster();
        for(UINT i=0; i<numSamplesInClusterA; i++){
            const UINT index = clusterData[ indexA ][ i ];
            newCluster.addSampleToCluster( index );
        }

        const UINT numSamplesInClusterB = clusterData[ indexB ].getNumSamplesInCluster();
        for(UINT i=0; i<numSamplesInClusterB; i++){
            const UINT index = clusterData[ indexB ][ i ];
            newCluster.addSampleToCluster( index );
        }

        //Compute the cluster variance
        newCluster.clusterVariance = computeClusterVariance( newCluster, data );

        //Remove the two merged clusters so they are not used in the next search.
        //Erase the higher index first so the lower index is still valid; the relative
        //order of the remaining clusters is preserved.
        const UINT highIndex = indexA > indexB ? indexA : indexB;
        const UINT lowIndex  = indexA > indexB ? indexB : indexA;
        clusterData.erase( clusterData.begin() + highIndex );
        clusterData.erase( clusterData.begin() + lowIndex );

        //Add the merged cluster to the working set
        clusterData.push_back( newCluster );

        //Record the new level (it holds the cluster created at this level)
        ClusterLevel mergedLevel;
        mergedLevel.level = level;
        mergedLevel.clusters.push_back( newCluster );
        clusters.push_back( mergedLevel );

        //Update the level
        level++;

        trainingLog << "Cluster level: " << level << " Number of clusters: " << clusters.back().getNumClusters() << std::endl;
    }

    //Flag that the model is trained
    trained = true;

    //Setup the cluster labels (numClusters is a member that is not set by this function)
    clusterLabels.resize(numClusters);
    for(UINT i=0; i<numClusters; i++){
        clusterLabels[i] = i+1;
    }
    clusterLikelihoods.resize(numClusters,0);
    clusterDistances.resize(numClusters,0);

    return true;
}