Ejemplo n.º 1
0
/**
 * Scores the target words that `hypo` appends with the n-gram model, adds the
 * (transformed) score to `scores`, and stores the resulting LM state in `state`.
 *
 * @param mgr        supplies the System handed on to Scores::PlusEquals
 * @param hypo       hypothesis whose current target word range is scored
 * @param prevState  LM state to start from (must be a KenLMState)
 * @param scores     score collection that receives the LM score
 * @param state      output LM state (must be a KenLMState)
 *
 * If the target phrase is empty, the previous state is copied through and no
 * score is added.
 */
void KENLM<Model>::EvaluateWhenApplied(const ManagerBase &mgr,
                                       const Hypothesis &hypo, const FFState &prevState, Scores &scores,
                                       FFState &state) const
{
  typedef typename Model::State NgramState;

  KenLMState &outState = static_cast<KenLMState&>(state);
  const System &system = mgr.system;
  const lm::ngram::State &prevNgramState =
    static_cast<const KenLMState&>(prevState).state;

  // Nothing was appended: the LM state is unchanged.
  if (hypo.GetTargetPhrase().GetSize() == 0) {
    outState.state = prevNgramState;
    return;
  }

  // Half-open range [firstPos, pastLastPos) of the newly added target words.
  const std::size_t firstPos = hypo.GetCurrTargetWordsRange().GetStartPos();
  const std::size_t pastLastPos = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
  // Only the first Order()-1 new words are scored incrementally.
  const std::size_t scoreLimit =
    std::min(pastLastPos, firstPos + m_ngram->Order() - 1);

  // Two state buffers are alternated: `current` always holds the most recent
  // output state, `spare` is the one the next Score() call may overwrite.
  NgramState scratch;
  NgramState *current = &outState.state;
  NgramState *spare = &scratch;

  float total = m_ngram->Score(prevNgramState,
                               TranslateID(hypo.GetWord(firstPos)),
                               *current);
  for (std::size_t pos = firstPos + 1; pos < scoreLimit; ++pos) {
    total += m_ngram->Score(*current, TranslateID(hypo.GetWord(pos)), *spare);
    std::swap(current, spare);
  }

  if (hypo.GetBitmap().IsComplete()) {
    // Sentence finished: add the end-of-sentence score.
    std::vector<lm::WordIndex> context(m_ngram->Order() - 1);
    lm::WordIndex *const contextBegin = &context.front();
    const lm::WordIndex *contextEnd = LastIDs(hypo, contextBegin);
    total += m_ngram->FullScoreForgotState(
               contextBegin, contextEnd,
               m_ngram->GetVocabulary().EndSentence(), outState.state).prob;
  } else if (scoreLimit < pastLastPos) {
    // Long phrase: not every new word was scored above, so obtain the state
    // from the last word ids of the hypothesis instead.
    std::vector<lm::WordIndex> context(m_ngram->Order() - 1);
    lm::WordIndex *const contextBegin = &context.front();
    const lm::WordIndex *contextEnd = LastIDs(hypo, contextBegin);
    m_ngram->GetState(contextBegin, contextEnd, outState.state);
  } else if (current != &outState.state) {
    // Short phrase whose final state landed in the scratch buffer: copy it out.
    outState.state = *current;
  }

  total = TransformLMScore(total);

  // The OOV feature is hard-wired off here, so only the single-score path runs.
  const bool OOVFeatureEnabled = false;
  if (!OOVFeatureEnabled) {
    scores.PlusEquals(system, *this, total);
  } else {
    std::vector<float> scoresVec(2);
    scoresVec[0] = total;
    scoresVec[1] = 0.0;
    scores.PlusEquals(system, *this, scoresVec);
  }
}
/**
 * Fires sparse global-lexical-model features for every word of the current
 * target phrase paired with every word of the input sentence.
 *
 * Feature names start with "glm_<target>~" and are added with value 1, either
 * directly through SparsePlusEquals or through AddFeature (biphrase/bitrigger
 * modes). The configuration members select what is produced:
 *   - m_ignorePunctuation: skip words whose first character is found in
 *                          m_punctuationHash.
 *   - m_biasFeature:       one "**BIAS**" feature per target word.
 *   - m_unrestricted:      when false, source and target words (and context /
 *                          trigger words) must be present in m_vocabSource /
 *                          m_vocabTarget.
 *   - m_sourceContext, m_biphrase, m_bitrigger: context-enriched variants,
 *                          tested in that order; if none is set a plain
 *                          source-target pair feature is fired.
 *
 * @param cur_hypo     hypothesis whose current target phrase is evaluated
 * @param accumulator  score collection that receives the sparse features
 */
void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const
{
  const Sentence& input = *(m_local->input);
  const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();

  for(size_t targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) {
    StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors

    if (m_ignorePunctuation) {
      // check if first char is punctuation
      char firstChar = targetString[0];
      CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
    }

    if (m_biasFeature) {
      stringstream feature;
      feature << "glm_";
      feature << targetString;
      feature << "~";
      feature << "**BIAS**";
      accumulator->SparsePlusEquals(feature.str(), 1);
    }

    // Hashes of source words that already fired for this target word. Only the
    // source-context and plain modes insert into it; biphrase and bitrigger
    // modes never do.
    boost::unordered_set<uint64_t> alreadyScored;
    for(size_t sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) {
      const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0);
      // TODO: change for other factors

      if (m_ignorePunctuation) {
        // check if first char is punctuation
        char firstChar = sourceString[0];
        CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
        if(charIterator != m_punctuationHash.end())
          continue;
      }
      const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size());

      if ( alreadyScored.find(sourceHash) == alreadyScored.end()) {
        // Initialized so they are never read indeterminate; only consulted
        // when the vocabulary restriction is active.
        bool sourceExists = false;
        bool targetExists = false;
        if (!m_unrestricted) {
          sourceExists = FindStringPiece(m_vocabSource, sourceString ) != m_vocabSource.end();
          targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end();
        }

        // no feature if vocab is in use and both words are not in restricted vocabularies
        if (m_unrestricted || (sourceExists && targetExists)) {
          if (m_sourceContext) {
            if (sourceIndex == 0) {
              // add <s> trigger feature for source
              stringstream feature;
              feature << "glm_";
              feature << targetString;
              feature << "~";
              feature << "<s>,";
              feature << sourceString;
              accumulator->SparsePlusEquals(feature.str(), 1);
              alreadyScored.insert(sourceHash);
            }

            // add source words to the right of current source word as context
            // (unsigned index: avoids a signed/unsigned comparison with GetSize())
            for(size_t contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) {
              StringPiece contextString = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
              bool contextExists = false;
              if (!m_unrestricted)
                contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end();

              if (m_unrestricted || contextExists) {
                stringstream feature;
                feature << "glm_";
                feature << targetString;
                feature << "~";
                feature << sourceString;
                feature << ",";
                feature << contextString;
                accumulator->SparsePlusEquals(feature.str(), 1);
                alreadyScored.insert(sourceHash);
              }
            }
          } else if (m_biphrase) {
            // --> look backwards for constructing context
            int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;

            // 1) source-target pair, trigger source word (can be discont.) and adjacent target word (bigram)
            StringPiece targetContext;
            if (globalTargetIndex > 0)
              targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0); // TODO: change for other factors
            else
              targetContext = "<s>";

            if (sourceIndex == 0) {
              StringPiece sourceTrigger = "<s>";
              AddFeature(accumulator, sourceTrigger, sourceString,
                         targetContext, targetString);
            } else {
              for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
                StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
                bool sourceTriggerExists = false;
                if (!m_unrestricted)
                  sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

                if (m_unrestricted || sourceTriggerExists)
                  AddFeature(accumulator, sourceTrigger, sourceString,
                             targetContext, targetString);
              }
            }

            // 2) source-target pair, adjacent source word (bigram) and trigger target word (can be discont.)
            // sourceIndex is unsigned, so "sourceIndex-1 >= 0" would always be
            // true and index the input at (size_t)-1 for the first source word;
            // test "> 0" so the first word gets the "<s>" context instead.
            StringPiece sourceContext;
            if (sourceIndex > 0)
              sourceContext = input.GetWord(sourceIndex-1).GetString(0); // TODO: change for other factors
            else
              sourceContext = "<s>";

            if (globalTargetIndex == 0) {
              string targetTrigger = "<s>";
              AddFeature(accumulator, sourceContext, sourceString,
                         targetTrigger, targetString);
            } else {
              for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
                StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
                bool targetTriggerExists = false;
                if (!m_unrestricted)
                  targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();

                if (m_unrestricted || targetTriggerExists)
                  AddFeature(accumulator, sourceContext, sourceString,
                             targetTrigger, targetString);
              }
            }
          } else if (m_bitrigger) {
            // allow additional discont. triggers on both sides
            int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;

            if (sourceIndex == 0) {
              StringPiece sourceTrigger = "<s>";
              bool sourceTriggerExists = true;

              if (globalTargetIndex == 0) {
                string targetTrigger = "<s>";
                bool targetTriggerExists = true;

                if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                  AddFeature(accumulator, sourceTrigger, sourceString,
                             targetTrigger, targetString);
              } else {
                // iterate backwards over target
                for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
                  StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
                  bool targetTriggerExists = false;
                  if (!m_unrestricted)
                    targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();

                  if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                    AddFeature(accumulator, sourceTrigger, sourceString,
                               targetTrigger, targetString);
                }
              }
            }
            // iterate over both source and target
            else {
              // iterate backwards over source
              for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
                StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
                bool sourceTriggerExists = false;
                if (!m_unrestricted)
                  sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

                if (globalTargetIndex == 0) {
                  string targetTrigger = "<s>";
                  bool targetTriggerExists = true;

                  if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                    AddFeature(accumulator, sourceTrigger, sourceString,
                               targetTrigger, targetString);
                } else {
                  // iterate backwards over target
                  for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
                    StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
                    bool targetTriggerExists = false;
                    if (!m_unrestricted)
                      targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();

                    if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                      AddFeature(accumulator, sourceTrigger, sourceString,
                                 targetTrigger, targetString);
                  }
                }
              }
            }
          } else {
            // plain mode: one feature per distinct (target word, source word) pair
            stringstream feature;
            feature << "glm_";
            feature << targetString;
            feature << "~";
            feature << sourceString;
            accumulator->SparsePlusEquals(feature.str(), 1);
            alreadyScored.insert(sourceHash);

          }
        }
      }
    }
  }
}
Ejemplo n.º 3
0
size_t LM::Evaluate(
  const Hypothesis& hypo,
  size_t prevState,
  Scores &scores) const
{
  if (m_order <= 1) {
    return 0; // not sure if returning NULL is correct
  }

  if (hypo.targetPhrase.GetSize() == 0) {
    return 0; // not sure if returning NULL is correct
  }

    PhraseVec m_phraseVec(m_order);

  const size_t currEndPos = hypo.targetRange.endPos;
  const size_t startPos = hypo.targetRange.startPos;

  size_t index = 0;
  for (int currPos = (int) startPos - (int) m_order + 1 ; currPos <= (int) startPos ; currPos++) {
    if (currPos >= 0)
      m_phraseVec[index++] = &hypo.GetWord(currPos);
    else {
      m_phraseVec[index++] = &m_bos;
    }
  }

  SCORE lmScore = GetValueCache(m_phraseVec);

  // main loop
  size_t endPos = std::min(startPos + m_order - 2
                           , currEndPos);
  for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
    // shift all args down 1 place
    for (size_t i = 0 ; i < m_order - 1 ; i++)
      m_phraseVec[i] = m_phraseVec[i + 1];

    // add last factor
    m_phraseVec.back() = &hypo.GetWord(currPos);

    lmScore	+= GetValueCache(m_phraseVec);
  }

  // end of sentence
  if (hypo.GetCoverage().IsComplete()) {
    const size_t size = hypo.GetSize();
    m_phraseVec.back() = &m_eos;

    for (size_t i = 0 ; i < m_order - 1 ; i ++) {
      int currPos = (int)(size - m_order + i + 1);
      if (currPos < 0)
        m_phraseVec[i] = &m_bos;
      else
        m_phraseVec[i] = &hypo.GetWord((size_t)currPos);
    }
    lmScore += GetValueCache(m_phraseVec);
  } else {
    if (endPos < currEndPos) {
      //need to get the LM state (otherwise the last LM state is fine)
      for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
        for (size_t i = 0 ; i < m_order - 1 ; i++)
          m_phraseVec[i] = m_phraseVec[i + 1];
        m_phraseVec.back() = &hypo.GetWord(currPos);
      }
    }
  }

  size_t state = GetLastState();
  return state;
}