Пример #1
0
// Runs the selector against every entry of the given test data suite.
//
// For each entry: the entry's update string is appended to the stream, the
// selector is run on the entry's input prediction, and the selected tokens
// are compared (count first, then word by word) with the entry's expected
// output prediction. The context tracker is updated before moving on to the
// next entry.
void SelectorTest::testSelect(TestDataSuite* tds)
{
    for (; tds->hasMoreTestData(); tds->nextTestData()) {
        // trace the stream content and the text about to be appended
        std::cerr << "Updating strstream: " << strstream->str() << '|' << std::endl;
        std::cerr << " with: " << tds->getUpdateString() << '|' << std::endl;
        *strstream << tds->getUpdateString();

        const std::vector<std::string> tokens =
            selector->select(tds->getInputPrediction());

        Prediction expected = tds->getOutputPrediction();
        CPPUNIT_ASSERT_EQUAL( (size_t)expected.size(), tokens.size() );

        // sizes match at this point, so walking the tokens also walks
        // every expected suggestion
        int position = 0;
        for (std::vector<std::string>::const_iterator token = tokens.begin();
             token != tokens.end();
             ++token, ++position) {
            const std::string expectedWord =
                expected.getSuggestion(position).getWord();

            std::cerr << "[expected] " << expectedWord
                      << "  [actual] " << *token << std::endl;
            CPPUNIT_ASSERT_EQUAL(expectedWord, *token);
        }

        contextTracker->update();
    }
}
// Iterates the registry to its last predictor and verifies the prediction
// that predictor returns (size, first and last suggestion).
void PredictorRegistryTest::testNext()
{
    // Bogus sentinel address handed to the registry as its context tracker.
    // NOTE(review): presumably the registry only stores/forwards this
    // pointer and nothing dereferences it during this test - confirm.
    ContextTracker* pointer = static_cast<ContextTracker*>((void*)0xdeadbeef);
    registry->setContextTracker(pointer);

    PredictorRegistry::Iterator it = registry->iterator();
    Predictor* predictor = 0;

    while (it.hasNext()) {
	predictor = it.next();
    }

    // The null check must come before the first use of the pointer: if the
    // registry is empty, the loop above never assigns it, and calling
    // predict() would dereference a null pointer instead of reporting a
    // test failure.
    CPPUNIT_ASSERT(predictor != 0);

    // since we've iterated till the end of the predictors list, predictor
    // is now pointing to the DummyPredictor, so let's test we got the
    // dummy prediction back
    Prediction prediction = predictor->predict(20, 0);

    size_t expected_size = 18;
    CPPUNIT_ASSERT_EQUAL(expected_size, prediction.size());
    CPPUNIT_ASSERT_EQUAL(Suggestion("foo1", 0.99), prediction.getSuggestion(0));
    CPPUNIT_ASSERT_EQUAL(Suggestion("foobar6", 0.74), prediction.getSuggestion(17));
}
Пример #3
0
// Two predictions are equal when they hold the same number of suggestions
// and every suggestion compares equal to the one at the same position in
// the other prediction.
bool Prediction::operator== (const Prediction& right) const
{
    // comparing an object with itself needs no further work
    if (this == &right) {
        return true;
    }

    // differing lengths can never be equal
    if (size() != right.size()) {
        return false;
    }

    // element-wise comparison, bailing out on the first mismatch
    for (size_t pos = 0; pos < size(); ++pos) {
        if (getSuggestion(pos) != right.getSuggestion(pos)) {
            return false;
        }
    }
    return true;
}
void NewSmoothedNgramPluginTest::testLearning()
{
    // get pointer to plugin
    Plugin* plugin = pluginRegistry->iterator().next();

    {
	*stream << "f";
	ct->update();
	Prediction actual = plugin->predict(SIZE, 0);
	CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(0), actual.size());
    }

    {
	*stream << "o";
	ct->update();
	Prediction actual = plugin->predict(SIZE, 0);
	CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(0), actual.size());
    }

    {
	*stream << "o ";
	ct->update();
	Prediction actual = plugin->predict(SIZE, 0);
	CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(1), actual.size());
	CPPUNIT_ASSERT_EQUAL(std::string("foo"), actual.getSuggestion(0).getWord());
	ct->update();
    }

    {
	*stream << "bar";
	ct->update();
	Prediction actual = plugin->predict(SIZE, 0);
    }

    {
	*stream << " ";
	ct->update();
	Prediction actual = plugin->predict(SIZE, 0);
	CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(2), actual.size());
	CPPUNIT_ASSERT_EQUAL(std::string("foo"), actual.getSuggestion(0).getWord());
	CPPUNIT_ASSERT_EQUAL(std::string("bar"), actual.getSuggestion(1).getWord());
    }

    {
	*stream << "foobar ";
	ct->update();
	Prediction actual = plugin->predict(SIZE, 0);
	CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(3), actual.size());
	CPPUNIT_ASSERT_EQUAL(std::string("foobar"), actual.getSuggestion(0).getWord());
	CPPUNIT_ASSERT_EQUAL(std::string("foo"), actual.getSuggestion(1).getWord());
	CPPUNIT_ASSERT_EQUAL(std::string("bar"), actual.getSuggestion(2).getWord());
    }

    {
	*stream << "f";
	ct->update();
	Prediction actual = plugin->predict(SIZE, 0);
	CPPUNIT_ASSERT_EQUAL(static_cast<size_t>(2), actual.size());
	CPPUNIT_ASSERT_EQUAL(std::string("foobar"), actual.getSuggestion(0).getWord());
	CPPUNIT_ASSERT_EQUAL(std::string("foo"), actual.getSuggestion(1).getWord());
    }
}
// Returns a copy of `text` with every single quote doubled, so the result
// can be embedded inside a single-quoted SQL string literal without
// terminating it early.
static std::string escapeSqlLiteral(const std::string& text)
{
    std::string escaped;
    escaped.reserve(text.size());
    for (std::string::size_type i = 0; i < text.size(); ++i) {
	if (text[i] == '\'') {
	    escaped += '\'';
	}
	escaped += text[i];
    }
    return escaped;
}

// Builds a prediction by combining unigram, bigram and trigram counts read
// from the database.
//
// The current prefix and the two preceding tokens are taken from the
// context tracker (lower-cased). Three queries fetch the words starting
// with the prefix from the _1_gram, _2_gram and _3_gram tables; for every
// unigram result the matching bigram and trigram values are added, each
// scaled by its weight (unigram_weight, bigram_weight, trigram_weight).
//
// NOTE(review): max_partial_predictions_size and filter are not used by
// this implementation; the row limit handed to the callback is the
// MAX_PARTIAL_PREDICTION_SIZE constant. Confirm whether that is intended.
//
// Returns the combined prediction, one suggestion per unigram result.
Prediction SmoothedCountPlugin::predict(const size_t max_partial_predictions_size, const char** filter) const
{
    // get w_2, w_1, and prefix from the context tracker
    std::string prefix = strtolower( contextTracker->getPrefix() );
    std::string word_1 = strtolower( contextTracker->getToken(1) );
    std::string word_2 = strtolower( contextTracker->getToken(2) );

    // The context text ends up inside SQL string literals, so it must be
    // escaped. Single-quoted literals are used on purpose: SQLite treats a
    // double-quoted token as an identifier first, so a context word such as
    // "word" or "count" would be compared as a column instead of as text.
    // NOTE(review): LIKE wildcards ('%', '_') occurring in the prefix itself
    // are still interpreted as wildcards.
    const std::string sqlPrefix = escapeSqlLiteral(prefix);
    const std::string sqlWord_1 = escapeSqlLiteral(word_1);
    const std::string sqlWord_2 = escapeSqlLiteral(word_2);

    std::string query; // string used to build sql query
    int result;        // database interrogation diagnostic
    CallbackData data; // data to pass through to callback function


    // get most likely unigrams whose w starts with prefix
    Prediction predUnigrams;

    // buildPrediction receives &data for every result row; presumably it
    // appends the row to *data.predPtr, up to data.predSize entries.
    data.predPtr = &predUnigrams;
    data.predSize = MAX_PARTIAL_PREDICTION_SIZE;

    query =
	"SELECT word, count "
	"FROM _1_gram "
	"WHERE word LIKE '" + sqlPrefix + "%' "
	"ORDER BY count DESC;";

#if defined(HAVE_SQLITE3_H)
    result = sqlite3_exec(
#elif defined(HAVE_SQLITE_H)
    result = sqlite_exec(
#endif
	db,
	query.c_str(),
	buildPrediction,
	&data,
	NULL
    );
    assert(result == SQLITE_OK);


    // get most likely bigrams having matching w_1 whose w starts with prefix
    Prediction predBigrams;

    data.predPtr = &predBigrams;

    // The trailing '%' makes this a prefix match like the unigram query;
    // without it only words identical to the prefix would be returned.
    query =
	"SELECT word, count "
	"FROM _2_gram "
	"WHERE word_1 = '" + sqlWord_1 + "' "
	"AND word LIKE '" + sqlPrefix + "%' "
	"ORDER BY count DESC;";

#if defined(HAVE_SQLITE3_H)
    result = sqlite3_exec(
#elif defined(HAVE_SQLITE_H)
    result = sqlite_exec(
#endif
	db,
	query.c_str(),
	buildPrediction,
	&data,
	NULL
    );
    assert(result == SQLITE_OK);


    // get most likely trigrams having matching w_2, w_1 whose w starts with
    // prefix
    Prediction predTrigrams;

    data.predPtr = &predTrigrams;

    query =
	"SELECT word, count "
	"FROM _3_gram "
	"WHERE word_2 = '" + sqlWord_2 + "' "
	"AND word_1 = '" + sqlWord_1 + "' "
	"AND word LIKE '" + sqlPrefix + "%' "
	"ORDER BY count DESC;";

#if defined(HAVE_SQLITE3_H)
    result = sqlite3_exec(
#elif defined(HAVE_SQLITE_H)
    result = sqlite_exec(
#endif
	db,
	query.c_str(),
	buildPrediction,
	&data,
	NULL
    );
    assert(result == SQLITE_OK);


    Prediction p;     // combined result of uni/bi/tri gram predictions
    std::string word; // pivot unigram word (used in next for loop)
    double ccount;    // combined count

    // compute smoothed probability estimation

    // TODO !!!!!!!! Everything should be scaled down to probabilities!!!
    // TODO That means that counts should be scaled down to values between
    // TODO 0 and 1. We need total word count to do that.

    // TODO : after correct word has been found in inner loops, execution
    // TODO : can break out of it.
    for (size_t i = 0; i < predUnigrams.size(); i++) {

	word   = predUnigrams.getSuggestion( i ).getWord();
	ccount = unigram_weight *
	    predUnigrams.getSuggestion( i ).getProbability();

	// add the weighted bigram value for this word and, only when a
	// bigram entry exists, the weighted trigram value(s) as well
	for (size_t j = 0; j < predBigrams.size(); j++) {

	    if( predBigrams.getSuggestion(j).getWord() == word ) {

		for (size_t k = 0; k < predTrigrams.size(); k++ ) {

		    if( predTrigrams.getSuggestion(k).getWord() == word ) {

			ccount += trigram_weight *
			    predTrigrams.getSuggestion(k).getProbability();

		    }
		}

		ccount += bigram_weight *
		    predBigrams.getSuggestion(j).getProbability();

	    }

	}

	p.addSuggestion( Suggestion( word, ccount ) );

    }

    return p; // Return combined prediction
}