Exemplo n.º 1
0
Prediction DummyPlugin::predict(const size_t max_partial_predictions_size, const char** filter) const
{
    // This plugin consults no resource: it always returns the same fixed
    // list of eighteen suggestions, in the same order and with hard-coded
    // probabilities (the suggestions are not random). Neither
    // max_partial_predictions_size nor filter is consulted, so callers
    // always receive the full list.
    //
    struct CannedSuggestion {
        const char* word;
        double      probability;
    };
    static const CannedSuggestion canned[] = {
        { "foo1",    0.99 }, { "foo2",    0.98 }, { "foo3",    0.97 },
        { "foo4",    0.96 }, { "foo5",    0.95 }, { "foo6",    0.94 },

        { "bar1",    0.89 }, { "bar2",    0.88 }, { "bar3",    0.87 },
        { "bar4",    0.86 }, { "bar5",    0.85 }, { "bar6",    0.84 },

        { "foobar1", 0.79 }, { "foobar2", 0.78 }, { "foobar3", 0.77 },
        { "foobar4", 0.76 }, { "foobar5", 0.75 }, { "foobar6", 0.74 }
    };
    static const size_t canned_count = sizeof(canned) / sizeof(canned[0]);

    Prediction suggestions;
    for (size_t n = 0; n < canned_count; ++n) {
        suggestions.addSuggestion(Suggestion(canned[n].word, canned[n].probability));
    }
    return suggestions;
}
Exemplo n.º 2
0
/** Suggests dictionary words that start with the current prefix.

    Words are read from the file at dictionary_path as whitespace-separated
    tokens, in file order. Every token that starts with the prefix becomes a
    suggestion carrying the plugin's `probability` value, until
    max_partial_predictions_size suggestions have been collected. The
    filter argument is not used by this plugin.
*/
Prediction DictionaryPlugin::predict(const size_t max_partial_predictions_size, const char** filter) const
{
    Prediction result;

    const std::string prefix = contextTracker->getPrefix();

    // The stream is closed automatically when it goes out of scope.
    std::ifstream dictionary_file(dictionary_path.c_str());
    if (!dictionary_file) {
        logger << ERROR << "Error opening dictionary: " << dictionary_path << endl;
    }
    assert(dictionary_file); // REVISIT: handle with exceptions
    if (!dictionary_file) {
        // Only reachable when NDEBUG compiles the assert out: return the
        // empty prediction explicitly rather than reading a failed stream.
        return result;
    }

    // Check the limit BEFORE reading, so that no token is read (and thrown
    // away) once enough suggestions have been collected. The counter uses
    // the same type as the limit it is compared against.
    size_t count = 0;
    std::string candidate;
    while (count < max_partial_predictions_size && dictionary_file >> candidate) {
        if (candidate.find(prefix) == 0) {
            result.addSuggestion(Suggestion(candidate, probability));
            count++;
            logger << NOTICE << "Found valid token: " << candidate << endl;
        } else {
            logger << INFO << "Discarding invalid token: " << candidate << endl;
        }
    }

    return result;
}
/** SQLite callback function.

    Builds a prediction from query results. sqlite_exec()/sqlite3_exec()
    invokes it once per result row. It expects exactly two columns, named
    "word" and "count" in that order. Each row is appended to the Prediction
    pointed to by CallbackData::predPtr, as a Suggestion whose probability is
    the numeric value of the "count" column.

    Returns 0 to let SQLite continue. Returns 1 once CallbackData::predSize
    suggestions have been collected, which makes SQLite stop the query and
    report SQLITE_ABORT to the caller. Terminates the process if invoked on
    a result set with a different column layout.
*/
int buildPrediction( void* callbackDataPtr,
		     int argc,
		     char** argv,
		     char** column )
{
	// cast pointer to void back to pointer to CallbackData object
	CallbackData* dataPtr = static_cast<CallbackData*>(callbackDataPtr);

	Prediction* predictionPtr = dataPtr->predPtr;
	const size_t maxPredictionSize = dataPtr->predSize;

	// Stop as soon as the prediction is full. (Using '>' here let one
	// extra suggestion through, i.e. maxPredictionSize + 1 in total.)
	if (predictionPtr->size() >= maxPredictionSize) {
		return 1;
	}

	if( argc != 2 ||
	    strcmp( "word", column[ 0 ] ) != 0 ||
	    strcmp( "count", column[ 1 ] ) != 0 ) {
		std::cerr << "Invalid invocation of buildPrediction method!"
			  << std::endl;
		exit( 1 );
	}

	// SQL NULL values arrive as null pointers: skip rows with no word
	// (constructing a string from a null pointer is undefined behaviour)
	// and treat a missing count as zero.
	if (argv[ 0 ] == NULL) {
		return 0;
	}
	const double wordCount = (argv[ 1 ] != NULL) ? atof( argv[ 1 ] ) : 0.0;

	predictionPtr->addSuggestion( Suggestion( argv[ 0 ], wordCount ) );

	return 0;
}
Exemplo n.º 4
0
Prediction RecencyPredictor::predict (const size_t max, const char** filter) const
{
    Prediction result;

    const std::string prefix = contextTracker->getPrefix();
    logger << INFO << "prefix: " << prefix << endl;

    // With an empty prefix every previously seen token would be a
    // candidate, so the prediction would degenerate into replaying the
    // max most recent tokens in reverse order. Return nothing instead.
    if (prefix.empty()) {
        return result;
    }

    // Walk the context history backwards, starting from the most recent
    // token (look-back distance 1). Stop when the history is exhausted
    // (an empty token), when max suggestions have been gathered, or when
    // the look-back distance exceeds cutoff_threshold.
    Suggestion suggestion;
    for (size_t distance = 1; ; ++distance) {
        const std::string token = contextTracker->getToken(distance);
        if (token.empty()
            || result.size() >= max
            || distance > cutoff_threshold) {
            break;
        }

        logger << INFO << "token: " << token << endl;

        // Only tokens that start with the prefix and pass the filter
        // become suggestions.
        if (token.find(prefix) != 0
            || !token_satisfies_filter (token, prefix, filter)) {
            continue;
        }

        // Exponential decay with look-back distance:
        // n_0 * e^(-lambda * (distance - 1)).
        const double weight = n_0 * exp(-(lambda * (distance - 1)));
        logger << INFO << "probability: " << weight << endl;
        suggestion.setWord(token);
        suggestion.setProbability(weight);
        result.addSuggestion(suggestion);
    }

    return result;
}
Exemplo n.º 5
0
void DejavuPluginTest::testPredict()
{
    *stream << "polly wants a cracker ";
    ct->update();

    // get pointer to dejavu plugin
    Plugin* plugin = pluginRegistry->iterator().next();

    {
        *stream << "polly ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "wants ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "a ";
        Prediction expected;
        expected.addSuggestion(Suggestion("cracker", 1.0));
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    *stream << "soda ";
    ct->update();

    {
        *stream << "polly ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "wants ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "a ";
        Prediction expected;
        expected.addSuggestion(Suggestion("cracker", 1.0));
        expected.addSuggestion(Suggestion("soda",    1.0));
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    *stream << "cake ";
    ct->update();

    {
        *stream << "polly ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "wants ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "a ";
        Prediction expected;
        expected.addSuggestion(Suggestion("cake",    1.0));
        expected.addSuggestion(Suggestion("cracker", 1.0));
        expected.addSuggestion(Suggestion("soda",    1.0));
        CPPUNIT_ASSERT_EQUAL(expected, plugin->predict(SIZE, 0));
        ct->update();
    }
}
Exemplo n.º 6
0
void SelectorTest::TestDataSuite_S6_NR_T0::init()
{
    TestData* td;
    Prediction* ip;
    Prediction* op;

    td = new TestData;
    ip = new Prediction;
    op = new Prediction;
    ip->addSuggestion(Suggestion("foo",        0.9));
    ip->addSuggestion(Suggestion("foo1",       0.8));
    ip->addSuggestion(Suggestion("foo2",       0.7));
    ip->addSuggestion(Suggestion("foo3",       0.6));
    ip->addSuggestion(Suggestion("foo4",       0.5));
    ip->addSuggestion(Suggestion("foo5",       0.4));
    ip->addSuggestion(Suggestion("foo6",       0.3));
    ip->addSuggestion(Suggestion("foobar",     0.2));
    ip->addSuggestion(Suggestion("foobar1",    0.1));
    ip->addSuggestion(Suggestion("foobar2",    0.09));
    ip->addSuggestion(Suggestion("foobar3",    0.08));
    ip->addSuggestion(Suggestion("foobar4",    0.07));
    ip->addSuggestion(Suggestion("foobar5",    0.06));
    ip->addSuggestion(Suggestion("foobar6",    0.05));
    ip->addSuggestion(Suggestion("foobar7",    0.04));
    ip->addSuggestion(Suggestion("foobar8",    0.03));
    ip->addSuggestion(Suggestion("foobar9",    0.02));
    ip->addSuggestion(Suggestion("foobarfoo",  0.01));
    ip->addSuggestion(Suggestion("foobarfoo1", 0.009));
    ip->addSuggestion(Suggestion("foobarfoo2", 0.008));
    ip->addSuggestion(Suggestion("foobarfoo3", 0.007));
    op->addSuggestion(Suggestion("foo",     0.9));
    op->addSuggestion(Suggestion("foo1",    0.8));
    op->addSuggestion(Suggestion("foo2",    0.7));
    op->addSuggestion(Suggestion("foo3",    0.6));
    op->addSuggestion(Suggestion("foo4",    0.5));
    op->addSuggestion(Suggestion("foo5",    0.4));
    td->updateString     = "f";
    td->inputPrediction  = *ip;
    td->outputPrediction = *op;
    testData.push_back(*td);
    delete td;
    delete op;

    td = new TestData;
    op = new Prediction;
    op->addSuggestion(Suggestion("foo6",       0.3));
    op->addSuggestion(Suggestion("foobar",     0.2));
    op->addSuggestion(Suggestion("foobar1",    0.1));
    op->addSuggestion(Suggestion("foobar2",    0.09));
    op->addSuggestion(Suggestion("foobar3",    0.08));
    op->addSuggestion(Suggestion("foobar4",    0.07));
    td->updateString     = "o";
    td->inputPrediction  = *ip;
    td->outputPrediction = *op;
    testData.push_back(*td);
    delete td;
    delete op;
    
    td = new TestData;
    op = new Prediction;
    op->addSuggestion(Suggestion("foobar5",    0.06));
    op->addSuggestion(Suggestion("foobar6",    0.05));
    op->addSuggestion(Suggestion("foobar7",    0.04));
    op->addSuggestion(Suggestion("foobar8",    0.03));
    op->addSuggestion(Suggestion("foobar9",    0.02));
    op->addSuggestion(Suggestion("foobarfoo",  0.01));
    td->updateString     = "o";
    td->inputPrediction  = *ip;
    td->outputPrediction = *op;
    testData.push_back(*td);
    delete td;
    delete op;

    iter = testData.begin();
}
/** Returns text with every single quote doubled, so that it can be placed
    inside a single-quoted SQL string literal without terminating it.
*/
static std::string smoothedCountEscapeSqlText(const std::string& text)
{
    std::string escaped;
    escaped.reserve(text.size());
    for (std::string::size_type i = 0; i < text.size(); ++i) {
        if (text[i] == '\'') {
            escaped += '\'';
        }
        escaped += text[i];
    }
    return escaped;
}

/** Predicts completions of the current prefix by combining weighted
    unigram, bigram and trigram counts read from the n-gram database.

    One suggestion is produced per unigram starting with the prefix. Its
    value is unigram_weight * unigram count, plus bigram_weight * count of
    each matching (w_1, w) bigram and, for each such bigram,
    trigram_weight * count of each matching (w_2, w_1, w) trigram. The
    values are raw weighted counts, not probabilities (see TODO below).

    @param max_partial_predictions_size  not used: each query is capped at
           MAX_PARTIAL_PREDICTION_SIZE rows via CallbackData::predSize.
    @param filter  not used by this plugin.
*/
Prediction SmoothedCountPlugin::predict(const size_t max_partial_predictions_size, const char** filter) const
{
    // get w_2, w_1, and prefix from the context tracker
    const std::string prefix = strtolower( contextTracker->getPrefix() );
    const std::string word_1 = strtolower( contextTracker->getToken(1) );
    const std::string word_2 = strtolower( contextTracker->getToken(2) );

    // Values are embedded as single-quoted SQL literals with embedded quotes
    // escaped. Double quotes must not be used: SQLite first tries to resolve
    // a double-quoted token as a column name, so e.g. word_1 == "word" would
    // compare two columns. The LIKE pattern matches words starting with the
    // prefix.
    const std::string prefix_pattern = "'" + smoothedCountEscapeSqlText(prefix) + "%'";
    const std::string word_1_literal = "'" + smoothedCountEscapeSqlText(word_1) + "'";
    const std::string word_2_literal = "'" + smoothedCountEscapeSqlText(word_2) + "'";

    std::string query; // string used to build sql query
    int result;        // database interrogation diagnostic
    CallbackData data; // data to pass through to callback function


    // get most likely unigrams w starting with prefix
    Prediction predUnigrams;

    data.predPtr = &predUnigrams;
    data.predSize = MAX_PARTIAL_PREDICTION_SIZE;

    query =
	"SELECT word, count "
	"FROM _1_gram "
	"WHERE word LIKE " + prefix_pattern + " "
	"ORDER BY count DESC;";

#if defined(HAVE_SQLITE3_H)
    result = sqlite3_exec(
#elif defined(HAVE_SQLITE_H)
    result = sqlite_exec(
#endif
	db,
	query.c_str(),
	buildPrediction,
	&data,
	NULL
    );
    // buildPrediction stops the query once the prediction is full, in which
    // case SQLite returns SQLITE_ABORT; that is expected, not an error.
    assert(result == SQLITE_OK || result == SQLITE_ABORT);


    // get most likely bigrams having matching w_1 whose w starts with prefix
    Prediction predBigrams;

    data.predPtr = &predBigrams;

    query =
	"SELECT word, count "
	"FROM _2_gram "
	"WHERE word_1 = " + word_1_literal + " "
	"AND word LIKE " + prefix_pattern + " "
	"ORDER BY count DESC;";

#if defined(HAVE_SQLITE3_H)
    result = sqlite3_exec(
#elif defined(HAVE_SQLITE_H)
    result = sqlite_exec(
#endif
	db,
	query.c_str(),
	buildPrediction,
	&data,
	NULL
    );
    assert(result == SQLITE_OK || result == SQLITE_ABORT);


    // get most likely trigrams having matching w_2, w_1 whose w starts with prefix
    Prediction predTrigrams;

    data.predPtr = &predTrigrams;

    query =
	"SELECT word, count "
	"FROM _3_gram "
	"WHERE word_2 = " + word_2_literal + " "
	"AND word_1 = " + word_1_literal + " "
	"AND word LIKE " + prefix_pattern + " "
	"ORDER BY count DESC;";

#if defined(HAVE_SQLITE3_H)
    result = sqlite3_exec(
#elif defined(HAVE_SQLITE_H)
    result = sqlite_exec(
#endif
	db,
	query.c_str(),
	buildPrediction,
	&data,
	NULL
    );
    assert(result == SQLITE_OK || result == SQLITE_ABORT);


    Prediction p;     // combined result of uni/bi/tri gram predictions
    std::string word; // pivot unigram word (used in next for loop)
    double ccount;    // combined count

    // compute smoothed probability estimation

    // TODO !!!!!!!! Everything should be scaled down to probabilities!!!
    // TODO That means that counts should be scaled down to values between
    // TODO 0 and 1. We need total word count to do that.

    // TODO : after correct word has been found in inner loops, execution
    // TODO : can break out of it.
    for (size_t i = 0; i < predUnigrams.size(); i++) {

	word   = predUnigrams.getSuggestion( i ).getWord();
	ccount = unigram_weight *
	    predUnigrams.getSuggestion( i ).getProbability();

	for (size_t j = 0; j < predBigrams.size(); j++) {

	    if( predBigrams.getSuggestion(j).getWord() == word ) {

		for (size_t k = 0; k < predTrigrams.size(); k++ ) {

		    if( predTrigrams.getSuggestion(k).getWord() == word ) {

			ccount += trigram_weight *
			    predTrigrams.getSuggestion(k).getProbability();

		    }
		}

		ccount += bigram_weight *
		    predBigrams.getSuggestion(j).getProbability();

	    }

	}

	p.addSuggestion( Suggestion( word, ccount ) );

    }

    return p; // Return combined prediction
}
/** Predicts completions of the current prefix using linearly interpolated
    (smoothed) n-gram frequencies.

    The method works in two phases:
     1. Candidate collection: candidate words w_i are read from the n-gram
        tables, starting from the highest order (cardinality) and falling
        back to lower orders while fewer than max_partial_prediction_size
        distinct candidates have been found.
     2. Scoring: each candidate is scored as
            P(w_i) = sum_{k=0}^{cardinality-1} deltas[k] * num_k / den_k
        where num_k = count(tokens, 0, k+1) and den_k = count(tokens, -1, k),
        except that den_0 is the cached total unigram count. A term whose
        denominator is not positive contributes 0.
        NOTE(review): count() is defined elsewhere; presumably it returns
        the database count of the n-gram of the given length ending at the
        given offset from w_i -- confirm.
    Only candidates with a strictly positive score are added to the result.

    @param max_partial_prediction_size  maximum number of candidates that are
           collected and scored (and therefore of suggestions returned).
    @param filter  when non-null, candidates are fetched with
           db->getNgramLikeTableFiltered() instead of db->getNgramLikeTable();
           presumably a null-terminated array of strings -- confirm.
    @return the scored suggestions.
*/
Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens: tokens[cardinality - 1 - i] holds
    // contextTracker->getToken(i). So tokens[cardinality - 1] is the current
    // token w_i (later overwritten with each candidate), and
    // tokens[cardinality - 1 - k] is w_{i-k}.
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
	tokens[cardinality - 1 - i] = contextTracker->getToken(i);
	logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }

    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, because in a well-constructed ngram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts would take
    // precedence over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // n-gram table, falling back on lower order n-gram tables if the
    // initial completion set is smaller than required. Candidates
    // already found in a higher-order table are not added again.
    //
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table:
        // the last k cached tokens, ending with the current token
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

	if (logger.shouldLog()) {
	    logger << DEBUG << "prefix_ngram: ";
	    for (size_t r = 0; r < prefix_ngram.size(); r++) {
		logger << DEBUG << prefix_ngram[r] << ' ';
	    }
	    logger << DEBUG << endl;
	}

        // obtain initial prefix completion candidates; the second argument
        // is the number of candidates still missing (presumably used by the
        // database as a row limit -- confirm)
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
	    partial = db->getNgramLikeTable(prefix_ngram,max_partial_prediction_size - prefixCompletionCandidates.size());
	} else {
	    partial = db->getNgramLikeTableFiltered(prefix_ngram,filter, max_partial_prediction_size - prefixCompletionCandidates.size());
	}

        db->endTransaction();

	if (logger.shouldLog()) {
	    logger << DEBUG << "partial prefixCompletionCandidates" << endl
	           << DEBUG << "----------------------------------" << endl;
	    for (size_t j = 0; j < partial.size(); j++) {
		// NOTE: this k shadows the outer loop's k (logging only).
		for (size_t k = 0; k < partial[j].size(); k++) {
		    logger << DEBUG << partial[j][k] << " ";
		}
		// NOTE(review): unlike the surrounding lines, no DEBUG level
		// is set before this endl.
		logger << endl;
	    }
	}

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates, iterator it points to Ngram,
            // it->end() - 2 points to the token candidate (presumably the
            // last element of each row is its count -- confirm)
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }

    if (logger.shouldLog()) {
	logger << DEBUG << "prefixCompletionCandidates" << endl
	       << DEBUG << "--------------------------" << endl;
	for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
	    logger << DEBUG << prefixCompletionCandidates[j] << endl;
	}
    }

    // compute smoothed probabilities for all candidates, inside a single
    // database transaction
    //
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum(); 
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

	logger << DEBUG << "------------------" << endl;
	logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

	// interpolate the frequencies of orders 1 .. cardinality, each
	// weighted by deltas[k]
	double probability = 0;
	for (int k = 0; k < cardinality; k++) {
	    double numerator = count(tokens, 0, k+1);
	    // reuse cached unigrams_counts_sum to speed things up
	    double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
	    // an unseen context (denominator 0) contributes nothing
	    double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
	    probability += deltas[k] * frequency;

	    logger << DEBUG << "numerator:   " << numerator << endl;
	    logger << DEBUG << "denominator: " << denominator << endl;
	    logger << DEBUG << "frequency:   " << frequency << endl;
	    logger << DEBUG << "delta:       " << deltas[k] << endl;

            // for some sanity checks (compiled out under NDEBUG)
	    assert(numerator <= denominator);
	    assert(frequency <= 1);
	}

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

	// candidates with a zero score are dropped
	if (probability > 0) {
	    prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
	}
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}
Exemplo n.º 9
0
/** Predicts completions of the current prefix from an ARPA language model.

    Every vocabulary word accepted by matchesPrefixAndFilter() is scored with
    the highest-order estimate the context allows:
      - trigram backoff when both previous words are in the vocabulary;
      - bigram backoff when only the immediately previous word is;
      - the plain unigram log probability otherwise.
    Scored words are emitted in the order defined by the cmp comparator
    (presumably highest score first -- confirm), up to
    max_partial_prediction_size suggestions.
*/
Prediction ARPAPlugin::predict(const size_t max_partial_prediction_size, const char** filter) const
{
  logger << DEBUG << "predict()" << endl;
  Prediction prediction;

  // prefix is the word being typed, wd2 the word before it and wd1 the
  // word before wd2.
  std::string prefix = strtolower(contextTracker->getToken(0));
  std::string wd2Str = strtolower(contextTracker->getToken(1));
  std::string wd1Str = strtolower(contextTracker->getToken(2));

  // Candidates keyed by log score and ordered by cmp. Iterators must be
  // taken from this exact type: a multimap with the default comparator is
  // a different type, and mixing the two only compiles on standard
  // libraries whose iterators happen to ignore the comparator.
  typedef std::multimap<float,std::string,cmp> ScoredWords;
  ScoredWords result;

  logger << DEBUG << "["<<wd1Str<<"]"<<" ["<<wd2Str<<"] "<<"["<<prefix<<"]"<<endl;

  //search for the past tokens in the vocabulary
  std::map<std::string,int>::const_iterator wd1It,wd2It;
  wd1It = vocabCode.find(wd1Str);
  wd2It = vocabCode.find(wd2Str);

  /**
   * When fewer than two past tokens are known, 2-gram or 1-gram
   * probabilities are used instead of 3-gram ones. The choice is made once,
   * outside the vocabulary loop, which is why the loop is repeated in each
   * branch.
   */

  //we have two valid past tokens available
  if(wd1It!=vocabCode.end() && wd2It!=vocabCode.end())
  {
    for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it)
    {
      if(matchesPrefixAndFilter(it->second,prefix,filter))
      {
        const float score = computeTrigramBackoff(wd1It->second,wd2It->second,it->first);
        result.insert(ScoredWords::value_type(score, it->second));
      }
    }
  }

  //we have one valid past token available
  else if(wd2It!=vocabCode.end())
  {
    for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it)
    {
      if(matchesPrefixAndFilter(it->second,prefix,filter))
      {
        const float score = computeBigramBackoff(wd2It->second,it->first);
        result.insert(ScoredWords::value_type(score, it->second));
      }
    }
  }

  //we have no valid past token available
  else
  {
    for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it)
    {
      if(matchesPrefixAndFilter(it->second,prefix,filter))
      {
        // Skip vocabulary words that have no unigram record: dereferencing
        // the end() iterator returned by find() is undefined behaviour.
        if(unigramMap.find(it->first) == unigramMap.end())
          continue;
        const float score = unigramMap.find(it->first)->second.logProb;
        result.insert(ScoredWords::value_type(score, it->second));
      }
    }
  }


  // Scores are log values; exp() maps them back to probabilities.
  // NOTE(review): ARPA files store base-10 logarithms. Unless the model
  // loader converts them to natural logs, pow(10, x) would be the exact
  // inverse -- confirm against the loading code.
  size_t numSuggestions = 0;
  for(ScoredWords::const_iterator it = result.begin(); it != result.end() && numSuggestions < max_partial_prediction_size; ++it)
  {
    prediction.addSuggestion(Suggestion(it->second,exp(it->first)));
    numSuggestions++;
  }

  return prediction;
}
Exemplo n.º 10
0
void DejavuPredictorTest::testPredict()
{
    *stream << "polly wants a cracker ";
    ct->update();

    // get pointer to dejavu predictor
    Predictor* predictor = predictorRegistry->iterator().next();
    
    {
        *stream << "polly ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "wants ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "a ";
        Prediction expected;
        expected.addSuggestion(Suggestion("cracker", 1.0));
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    *stream << "soda ";
    ct->update();

    {
        *stream << "polly ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "wants ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "a ";
        Prediction expected;
        expected.addSuggestion(Suggestion("cracker", 1.0));
        expected.addSuggestion(Suggestion("soda",    1.0));
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    *stream << "cake ";
    ct->update();

    {
        *stream << "polly ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "wants ";
        Prediction expected;
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    {
        *stream << "a ";
        Prediction expected;
        expected.addSuggestion(Suggestion("cake",    1.0));
        expected.addSuggestion(Suggestion("cracker", 1.0));
        expected.addSuggestion(Suggestion("soda",    1.0));
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, 0));
        ct->update();
    }

    *stream << "crumble ";
    ct->update();

    {
        // test filter
        const char* filter[] = { "cra", "so", 0 };

        *stream << "polly wants a ";
        Prediction expected;
        expected.addSuggestion(Suggestion("cracker",    1.0));
        expected.addSuggestion(Suggestion("soda",    1.0));
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, filter));
        ct->update();
    }

    *stream << "break ";
    ct->update();

    {
        // test filter
        const char* filter[] = { "r", 0 };

        *stream << "polly wants a c";
        Prediction expected;
        expected.addSuggestion(Suggestion("cracker",    1.0));
        expected.addSuggestion(Suggestion("crumble",    1.0));
        CPPUNIT_ASSERT_EQUAL(expected, predictor->predict(SIZE, filter));
        ct->update();
    }

    *stream << "uddle ";
    ct->update();
}