void DatabaseConnectorTest::assertCorrectMockNgramTable(NgramTable ngramTable)
{
    CPPUNIT_ASSERT_EQUAL(size_t(3), ngramTable.size());

    CPPUNIT_ASSERT(ngramTable[0][0] == "foo");
    CPPUNIT_ASSERT(ngramTable[0][1] == "bar");
    CPPUNIT_ASSERT(ngramTable[0][2] == "foobar");
    CPPUNIT_ASSERT(ngramTable[0][3] == "3");

    CPPUNIT_ASSERT(ngramTable[1][0] == "bar");
    CPPUNIT_ASSERT(ngramTable[1][1] == "foo");
    CPPUNIT_ASSERT(ngramTable[1][2] == "foobar");
    CPPUNIT_ASSERT(ngramTable[1][3] == "33");

    CPPUNIT_ASSERT(ngramTable[2][0] == "foobar");
    CPPUNIT_ASSERT(ngramTable[2][1] == "bar");
    CPPUNIT_ASSERT(ngramTable[2][2] == "foo");
    CPPUNIT_ASSERT(ngramTable[2][3] == "333");
}
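// A minimal usage sketch, not part of the original sources: it assumes a
// CppUnit fixture with a `connector` member pointing at the mock database,
// and that querying with an empty trigram returns every row of the mock
// table (the test name and the wildcard convention are assumptions). It
// shows how the helper above would typically be exercised through the
// DatabaseConnector interface.
//
// void DatabaseConnectorTest::testGetMockNgramTable()
// {
//     Ngram ngram(3);  // empty trigram, assumed to match all rows
//     NgramTable table = connector->getNgramLikeTable(ngram, 3);
//     assertCorrectMockNgramTable(table);
// }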
Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens.
    // tokens[k] corresponds to w_{i-k} in the generalized smoothed
    // n-gram probability formula.
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = contextTracker->getToken(i);
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = "
               << tokens[cardinality - 1 - i] << endl;
    }

    // Generate the list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, because in a well-constructed n-gram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts will take
    // precedence over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // order n-gram table, falling back on lower order n-gram tables
    // if the initial completion set is smaller than required.
    //
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;

        // create the n-gram used to retrieve the initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;
        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram,
                                            max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter,
                                                    max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t k = 0; k < partial[j].size(); k++) {
                    logger << DEBUG << partial[j][k] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains "
               << partial.size() << " potential completions." << endl;
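        // Illustrative note on the layout of `partial`, assuming the live
        // table matches the mock table asserted in DatabaseConnectorTest
        // above: each row holds the n-gram's tokens followed by its
        // occurrence count, e.g. a trigram row
        //
        //     { "foo", "bar", "foobar", "3" }
        //
        // so *(it->end() - 2) is the completion token ("foobar") and
        // *(it->end() - 1) is its count ("3"). The loop below relies on
        // this layout; the sample values are taken from the mock table.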
        // Append newly discovered potential completions to the prefix
        // completion candidates vector to fill it up to
        // max_partial_prediction_size.
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates; iterator it points to an Ngram,
            // and it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // Compute smoothed probabilities for all candidates.
    //
    db->beginTransaction();

    // getUnigramCountsSum is an expensive SQL query;
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();

    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store the w_i candidate at the end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k + 1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator: " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency: " << frequency << endl;
            logger << DEBUG << "delta: " << deltas[k] << endl;

            // sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }

    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}
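// Worked example of the smoothing step above, with made-up counts (the
// deltas and the database contents are configuration-dependent). The inner
// loop computes the generalized linear interpolation
//
//     P(w_i | w_{i-n+1} ... w_{i-1}) = sum_{k=0}^{n-1} deltas[k] * f_k
//
// where f_0 = C(w_i) / (sum of all unigram counts) and, for k > 0,
// f_k = C(w_{i-k} ... w_i) / C(w_{i-k} ... w_{i-1}).
//
// For instance, with cardinality 2, deltas = {0.3, 0.7}, candidate
// "foobar" after context "bar", and hypothetical counts C("foobar") = 3,
// unigram counts sum = 30, C("bar foobar") = 2 and C("bar") = 4:
//
//     probability = 0.3 * (3 / 30) + 0.7 * (2 / 4) = 0.03 + 0.35 = 0.38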