Exemplo n.º 1
0
	/**
	 * Produces a prediction for one road by choosing between, or blending,
	 * two lazily created predictors.
	 *
	 * The road's jam count is compared against a fixed window:
	 *   - above the window: the midAvg (KNN) prediction is returned;
	 *   - below the window: the smallAvg (SmallAverage) prediction is returned;
	 *   - inside the window (endpoints included): the two predictions are
	 *     linearly interpolated by the jam count's position in the window.
	 *
	 * midAvg, smallAvg and cluster are not declared in this function
	 * (presumably members of the enclosing class — confirm).
	 *
	 * @param db        database passed through to every helper call
	 * @param roadIndex index of the road to predict for
	 * @param item      item forwarded unchanged to the underlying predictors
	 * @return the selected or interpolated prediction
	 */
	double predict(Database* db, int roadIndex, const Item& item)
	{
		const double jamCount = static_cast<double>(db->getJamCount(roadIndex));

		// The value is not consulted below (a disabled branch used to test
		// jamCount >= 30 && variance >= 25.0); the call itself is retained so
		// that whatever it does internally still happens.
		const double variance = cluster->getVar(db, roadIndex);
		(void)variance;

		// Jam-count window over which the two predictors are blended.
		// An earlier experiment used 480 .. 520 instead.
		const double windowLow = 780;
		const double windowHigh = 820;

		// Create and initialize each predictor the first time it is needed.
		if ( midAvg == 0 )
		{
			midAvg = new KNN(db, LOWER_LIMIT, UPPER_LIMIT, cluster);
			midAvg->initialize(db);
		}
		if ( smallAvg == 0 )
		{
			smallAvg = new SmallAverage(db, cluster, LOWER_LIMIT);
			smallAvg->initialize(db);
		}

		// Strictly above the window: KNN predictor only.
		if ( jamCount > windowHigh )
		{
			return midAvg->predict(db, roadIndex, item);
		}

		// Strictly below the window: small-average predictor only.
		if ( jamCount < windowLow )
		{
			return smallAvg->predict(db, roadIndex, item);
		}

		// Inside the window: weight is 0 at windowLow and 1 at windowHigh.
		const double midPrediction = midAvg->predict(db, roadIndex, item);
		const double smallPrediction = smallAvg->predict(db, roadIndex, item);
		const double weight = (jamCount - windowLow)/(windowHigh - windowLow);

		return smallPrediction*(1.0 - weight) + midPrediction*weight;
	}
Exemplo n.º 2
0
// Exercises the KNN classifier end to end on a generated Gaussian dataset:
// train, predict, save, clear, load, and predict again.
TEST(KNN, TrainBasicDataset) {

  const char* const datasetPath = "gauss_data.csv";
  const char* const modelPath = "knn_model.grt";

  KNN classifier;

  // A freshly constructed classifier must not report itself as trained
  EXPECT_TRUE( !classifier.getTrained() );

  // Write a generated dataset to disk, then read it back in
  const UINT sampleCount = 1000;
  const UINT classCount = 10;
  const UINT dimensionCount = 10;
  ClassificationData::generateGaussDataset( datasetPath, sampleCount, classCount, dimensionCount, 10, 1 );
  ClassificationData fullData;
  EXPECT_TRUE( fullData.load( datasetPath ) );

  // Split off the data used for the prediction checks below
  // (argument 50 is presumably a percentage — confirm against ClassificationData::split)
  ClassificationData heldOut = fullData.split( 50 );

  // Training must succeed and flip the trained flag
  EXPECT_TRUE( classifier.train( fullData ) );
  EXPECT_TRUE( classifier.getTrained() );

  EXPECT_TRUE( classifier.print() );

  // Every held-out sample must be accepted by predict()
  for(UINT sampleIndex=0; sampleIndex<heldOut.getNumSamples(); sampleIndex++){
    EXPECT_TRUE( classifier.predict( heldOut[sampleIndex].getSample() ) );
  }

  // Persist the model, then wipe the in-memory classifier
  EXPECT_TRUE( classifier.save( modelPath ) );
  classifier.clear();
  EXPECT_TRUE( !classifier.getTrained() );

  // Reloading the saved model must restore the trained state
  EXPECT_TRUE( classifier.load( modelPath ) );
  EXPECT_TRUE( classifier.getTrained() );

  // The reloaded model must accept the same held-out samples
  for(UINT sampleIndex=0; sampleIndex<heldOut.getNumSamples(); sampleIndex++){
    EXPECT_TRUE( classifier.predict( heldOut[sampleIndex].getSample() ) );
  }

}