예제 #1
1
/***
 * print surface factor only for the given phrase
 */
void OutputSurface(std::ostream &out, const Hypothesis &edge, const std::vector<FactorType> &outputFactorOrder,
		   bool reportSegmentation, bool reportAllFactors)
{
  CHECK(outputFactorOrder.size() > 0);
  const Phrase& phrase = edge.GetCurrTargetPhrase();
  if (reportAllFactors == true) {
    out << phrase;
  } else {
    size_t size = phrase.GetSize();
    for (size_t pos = 0 ; pos < size ; pos++) {
      const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[0]);
      out << *factor;
      CHECK(factor);

      for (size_t i = 1 ; i < outputFactorOrder.size() ; i++) {
        const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[i]);
        CHECK(factor);

        out << "|" << *factor;
      }
      out << " ";
    }
  }

  // trace option "-t"
  if (reportSegmentation == true && phrase.GetSize() > 0) {
    out << "|" << edge.GetCurrSourceWordsRange().GetStartPos()
	<< "-" << edge.GetCurrSourceWordsRange().GetEndPos() << "| ";
  }
}
예제 #2
0
/***
 * print surface factor only for the given phrase
 */
void OutputSurface(std::ostream &out, const Hypothesis &edge, const std::vector<FactorType> &outputFactorOrder,
                   char reportSegmentation, bool reportAllFactors)
{
  CHECK(outputFactorOrder.size() > 0);
  const Phrase& phrase = edge.GetCurrTargetPhrase();
  bool markUnknown = StaticData::Instance().GetMarkUnknown();
  if (reportAllFactors == true) {
    out << phrase;
  } else {
    FactorType placeholderFactor = StaticData::Instance().GetPlaceholderFactor().second;

    size_t size = phrase.GetSize();
    for (size_t pos = 0 ; pos < size ; pos++) {
      const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[0]);

      if (placeholderFactor != NOT_FOUND) {
        const Factor *origFactor = phrase.GetFactor(pos, placeholderFactor);
        if (origFactor) {
          factor = origFactor;
        }
      }
      CHECK(factor);

      //preface surface form with UNK if marking unknowns
      const Word &word = phrase.GetWord(pos);
      if(markUnknown && word.IsOOV()) {
	out << "UNK" << *factor;
      }
      else {
	out << *factor;
      }
      
      for (size_t i = 1 ; i < outputFactorOrder.size() ; i++) {
        const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[i]);
        CHECK(factor);

        out << "|" << *factor;
      }
      out << " ";
    }
  }

  // trace option "-t" / "-tt"
  if (reportSegmentation > 0 && phrase.GetSize() > 0) {
    const WordsRange &sourceRange = edge.GetCurrSourceWordsRange();
    const int sourceStart = sourceRange.GetStartPos();
    const int sourceEnd = sourceRange.GetEndPos();
    out << "|" << sourceStart << "-" << sourceEnd;
    // enriched "-tt"
    if (reportSegmentation == 2) {
      out << ",0, ";
      const AlignmentInfo &ai = edge.GetCurrTargetPhrase().GetAlignTerm();
      OutputAlignment(out, ai, 0, 0);
    }
    out << "| ";
  }
}
예제 #3
0
/**
 * Initialise `state` (a BidirectionalReorderingState) for the empty/initial
 * hypothesis: no previous state (NULL), the hypothesis' target phrase and
 * input path, the flag `true` (presumably "first phrase" — confirm against
 * Init's signature) and a pointer to the hypothesis' coverage bitmap.
 * `mgr` and `input` are not used.
 */
void LexicalReordering::EmptyHypothesisState(FFState &state,
    const ManagerBase &mgr, const InputType &input,
    const Hypothesis &hypo) const
{
  static_cast<BidirectionalReorderingState&>(state).Init(
      NULL,
      hypo.GetTargetPhrase(),
      hypo.GetInputPath(),
      true,
      &hypo.GetBitmap());
}
/// Append one phrase-alignment record for hypothesis `h` to `aInfo`.
/// Does nothing unless alignment info was requested (m_withAlignInfo).
/// The record is an XML-RPC struct with integer fields "tgt-start"
/// (start of the current target range), "src-start" and "src-end"
/// (bounds of the current source range).
void
TranslationRequest::
add_phrase_aln_info(Hypothesis const& h, vector<xmlrpc_c::value>& aInfo) const
{
  if (m_withAlignInfo) {
    WordsRange const& targetRange = h.GetCurrTargetWordsRange();
    WordsRange const& sourceRange = h.GetCurrSourceWordsRange();

    std::map<std::string, xmlrpc_c::value> record;
    record["tgt-start"] = xmlrpc_c::value_int(targetRange.GetStartPos());
    record["src-start"] = xmlrpc_c::value_int(sourceRange.GetStartPos());
    record["src-end"]   = xmlrpc_c::value_int(sourceRange.GetEndPos());

    aInfo.push_back(xmlrpc_c::value_struct(record));
  }
}
예제 #5
0
void LanguageModel::EvaluateWhenApplied(const ManagerBase &mgr,
                                        const Hypothesis &hypo, const FFState &prevState, Scores &scores,
                                        FFState &state) const
{
  const LMState &prevLMState = static_cast<const LMState &>(prevState);
  size_t numWords = prevLMState.numWords;

  // context is held backwards
  vector<const Factor*> context(numWords);
  for (size_t i = 0; i < numWords; ++i) {
    context[i] = prevLMState.lastWords[i];
  }
  //DebugContext(context);

  SCORE score = 0;
  std::pair<SCORE, void*> fromScoring;
  const TargetPhrase<Moses2::Word> &tp = hypo.GetTargetPhrase();
  for (size_t i = 0; i < tp.GetSize(); ++i) {
    const Word &word = tp[i];
    const Factor *factor = word[m_factorType];
    ShiftOrPush(context, factor);
    fromScoring = Score(context);
    score += fromScoring.first;
  }

  const Bitmap &bm = hypo.GetBitmap();
  if (bm.IsComplete()) {
    // everything translated
    ShiftOrPush(context, m_eos);
    fromScoring = Score(context);
    score += fromScoring.first;
    fromScoring.second = NULL;
    context.clear();
  } else {
    assert(context.size());
    if (context.size() == m_order) {
      context.resize(context.size() - 1);
    }
  }

  scores.PlusEquals(mgr.system, *this, score);

  // return state
  //DebugContext(context);

  LMState &stateCast = static_cast<LMState&>(state);
  MemPool &pool = mgr.GetPool();
  stateCast.Set(pool, fromScoring.second, context);
}
/**
 * Add this feature's score for the hypothesis' current target phrase to
 * `accumulator`; the score comes from GetFromCacheOrScorePhrase.
 */
void GlobalLexicalModel::Evaluate
(const Hypothesis& hypo,
 ScoreComponentCollection* accumulator) const
{
  accumulator->PlusEquals(this, GetFromCacheOrScorePhrase(hypo.GetCurrTargetPhrase()));
}
예제 #7
0
/**
 * Expand hypothesis `h` by fixing the tag of one more position.
 *
 * Among positions whose tag is still undecided (h.vt[j].cprd == ""), the one
 * with the lowest h.vent[j] is chosen (the earliest wins on ties). For each
 * candidate (tag, prob) in h.vvp[pos], a copy of `h` is made with:
 *   - that tag fixed,
 *   - order[pos] = order + 1,
 *   - prob = h.prob * candidate prob.
 * The still-undecided positions within TAG_WINDOW_SIZE of pos are then
 * re-predicted via Hypothesis::Update, and the copy is appended to `vh`.
 *
 * @param order  current decision step; the new tag is recorded as order + 1
 * @param h      hypothesis to expand; must have an undecided position (asserted)
 * @param vme    models passed through to Hypothesis::Update
 * @param vh     receives the expanded hypotheses
 */
void generate_hypotheses(const int order, const Hypothesis & h,
                         const vector<ME_Model> & vme,
                         list<Hypothesis> & vh)
{
  const int n = h.vt.size();

  // Pick the undecided position with minimum entropy. Seeding from the first
  // undecided position (rather than a magic upper bound) lets any entropy
  // value be selected.
  int pred_position = -1;
  double min_ent = 0;
  for (int j = 0; j < n; j++) {
    if (h.vt[j].cprd != "") continue;  // already decided
    const double ent = h.vent[j];
    if (pred_position < 0 || ent < min_ent) {
      min_ent = ent;
      pred_position = j;
    }
  }
  assert(pred_position >= 0 && pred_position < n);
  // With NDEBUG the assert vanishes; never index vvp[-1].
  if (pred_position < 0) return;

  for (vector<pair<string, double> >::const_iterator k = h.vvp[pred_position].begin();
       k != h.vvp[pred_position].end(); ++k) {
    Hypothesis newh = h;

    newh.vt[pred_position].cprd = k->first;
    newh.order[pred_position] = order + 1;
    newh.prob = h.prob * k->second;

    //    if (newh.IsErroneous()) {
    //      cout << "*errorneous" << endl;
    //      newh.Print();
    //      continue;
    //    }

    // update the neighboring predictions
    for (int j = pred_position - TAG_WINDOW_SIZE; j <= pred_position + TAG_WINDOW_SIZE; j++) {
      if (j < 0 || j > n-1) continue;
      if (newh.vt[j].cprd == "") newh.Update(j, vme);
    }
    vh.push_back(newh);
  }
}
/**
 * Distortion score for moving from source range `prev` to source range `curr`.
 *
 * Without early distortion cost this is the negated value of the input's
 * ComputeDistortionDistance(prev, curr).
 *
 * With early distortion cost, the cost is paid as soon as possible, following
 * Moore & Quirk (MT Summit 2007). Notation:
 *   S   : current source range (curr)
 *   S'  : last translated source phrase range (prev)
 *   S'' : longest fully-translated initial segment, ending at FirstGap - 1
 *         (FirstGap == -1 is treated like 0, i.e. an empty prefix)
 *
 * The four cases are tried in order (all return values are <= 0):
 *   1. S adjacent to S''          -> 0
 *   2. S left of S'               -> -2 * length(S)
 *   3. S' inside S''              -> -2 * (gap(S'', S) + length(S))
 *   4. otherwise                  -> -2 * (wordsBetween(S, S') + length(S))
 */
float
DistortionScoreProducer::
CalculateDistortionScore(const Hypothesis& hypo,
                         const Range &prev, const Range &curr, const int FirstGap)
{
  const bool earlyCost = hypo.GetManager().options()->reordering.use_early_distortion_cost;
  if (!earlyCost) {
    return -static_cast<float>(hypo.GetInput().ComputeDistortionDistance(prev, curr));
  }

  // last position of S''; -1 when the fully-translated prefix is empty
  const int prefixEndPos = (FirstGap == -1) ? -1 : FirstGap - 1;

  const int currStart  = static_cast<int>(curr.GetStartPos());
  const int currEnd    = static_cast<int>(curr.GetEndPos());
  const int prevEnd    = static_cast<int>(prev.GetEndPos());
  const int currLength = static_cast<int>(curr.GetNumWordsCovered());

  if (currStart == prefixEndPos + 1) {
    IFVERBOSE(4) std::cerr<< "MQ07disto:case1" << std::endl;
    return 0.0f;
  }

  if (currEnd < prevEnd) {
    IFVERBOSE(4) std::cerr<< "MQ07disto:case2" << std::endl;
    return -2.0f * currLength;
  }

  if (prevEnd <= prefixEndPos) {
    IFVERBOSE(4) std::cerr<< "MQ07disto:case3" << std::endl;
    const int gapAfterPrefix = currStart - prefixEndPos - 1;
    return -2.0f * (gapAfterPrefix + currLength);
  }

  IFVERBOSE(4) std::cerr<< "MQ07disto:case4" << std::endl;
  const int wordsBetween = static_cast<int>(curr.GetNumWordsBetween(prev));
  return -2.0f * (wordsBetween + currLength);
}
예제 #9
0
 /**
  * Combine hypotheses `a` and `b` into `ret`.
  *
  * ret takes a's hook and b's right side; a's predecessor ids are appended
  * to ret.prev_hyp, followed by b's id. Always returns 0.0.
  */
 double combine(const Hypothesis & a, const Hypothesis & b, Hypothesis & ret) const {
   ret.hook = a.hook;
   ret.right_side = b.right_side;
   // Unsigned index matches size(); taking the bound once keeps the copy
   // finite even if `ret` and `a` are the same object.
   const size_t count = a.prev_hyp.size();
   for (size_t i = 0; i < count; ++i) {
     ret.prev_hyp.push_back(a.prev_hyp[i]);
   }
   ret.prev_hyp.push_back(b.id());
   return 0.0;
 }
예제 #10
0
std::map<size_t, const Factor*> GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor)
{
  const InputPath &inputPath = hypo.GetTranslationOption().GetInputPath();
  const Phrase &inputPhrase = inputPath.GetPhrase();

  std::map<size_t, const Factor*> ret;

  for (size_t sourcePos = 0; sourcePos < inputPhrase.GetSize(); ++sourcePos) {
    const Factor *factor = inputPhrase.GetFactor(sourcePos, placeholderFactor);
    if (factor) {
      std::set<size_t> targetPos = hypo.GetTranslationOption().GetTargetPhrase().GetAlignTerm().GetAlignmentsForSource(sourcePos);
      CHECK(targetPos.size() == 1);
      ret[*targetPos.begin()] = factor;
    }
  }

  return ret;
}
/**
 * Build the recombination state for a phrase-based hypothesis.
 *
 * For SameOutput, the hypothesis' output phrase is stored via GetOutputPhrase
 * and m_hypo is set to NULL. Otherwise, the hypothesis' address is stored in
 * m_hypo. Setting m_hypo on both paths means the pointer is never left
 * indeterminate, which would be undefined behaviour to copy or compare.
 */
ControlRecombinationState::ControlRecombinationState(const Hypothesis &hypo, const ControlRecombination &ff)
  :m_ff(ff)
{
  if (ff.GetType() == SameOutput) {
    //UTIL_THROW(util::Exception, "Implemented not yet completed for phrase-based model. Need to take into account the coverage");
    hypo.GetOutputPhrase(m_outputPhrase);
    m_hypo = NULL;
  } else {
    m_hypo = &hypo;
  }
}
void PhraseBasedReorderingState::Expand(const ManagerBase &mgr,
    const LexicalReordering &ff, const Hypothesis &hypo, size_t phraseTableInd,
    Scores &scores, FFState &state) const
{
  if ((m_direction != LRModel::Forward) || !m_first) {
    LRModel const& lrmodel = m_configuration;
    Range const &cur = hypo.GetInputPath().range;
    LRModel::ReorderingType reoType = (
        m_first ?
            lrmodel.GetOrientation(cur) :
            lrmodel.GetOrientation(prevPath->range, cur));
    CopyScores(mgr.system, scores, hypo.GetTargetPhrase(), reoType);
  }

  PhraseBasedReorderingState &stateCast =
      static_cast<PhraseBasedReorderingState&>(state);
  stateCast.Init(this, hypo.GetTargetPhrase(), hypo.GetInputPath(), false,
      NULL);
}
예제 #13
0
std::map<size_t, const Factor*>
Hypothesis::
GetPlaceholders(const Hypothesis &hypo, FactorType placeholderFactor) const
{
  const InputPath &inputPath = hypo.GetTranslationOption().GetInputPath();
  const Phrase &inputPhrase = inputPath.GetPhrase();

  std::map<size_t, const Factor*> ret;

  for (size_t sourcePos = 0; sourcePos < inputPhrase.GetSize(); ++sourcePos) {
    const Factor *factor = inputPhrase.GetFactor(sourcePos, placeholderFactor);
    if (factor) {
      std::set<size_t> targetPos = hypo.GetTranslationOption().GetTargetPhrase().GetAlignTerm().GetAlignmentsForSource(sourcePos);
      UTIL_THROW_IF2(targetPos.size() != 1,
                     "Placeholder should be aligned to 1, and only 1, word");
      ret[*targetPos.begin()] = factor;
    }
  }

  return ret;
}
예제 #14
0
bool WorldModelROS::hypothesisToMsg(const Hypothesis& hyp, wire_msgs::WorldState& msg) const {
    ros::Time time = ros::Time::now();

    msg.header.frame_id = world_model_frame_id_;
    msg.header.stamp = time;

    for(list<SemanticObject*>::const_iterator it = hyp.getObjects().begin(); it != hyp.getObjects().end(); ++it) {

        SemanticObject* obj_clone = (*it)->clone();
        obj_clone->propagate(time.toSec());

        wire_msgs::ObjectState obj_msg;
        if (objectToMsg(*obj_clone, obj_msg)) {
            msg.objects.push_back(obj_msg);
        }

        delete obj_clone;

    }

    return true;
}
예제 #15
0
파일: LM.cpp 프로젝트: arvs/mosesdecoder
size_t LM::Evaluate(
  const Hypothesis& hypo,
  size_t prevState,
  Scores &scores) const
{
  if (m_order <= 1) {
    return 0; // not sure if returning NULL is correct
  }

  if (hypo.targetPhrase.GetSize() == 0) {
    return 0; // not sure if returning NULL is correct
  }

    PhraseVec m_phraseVec(m_order);

  const size_t currEndPos = hypo.targetRange.endPos;
  const size_t startPos = hypo.targetRange.startPos;

  size_t index = 0;
  for (int currPos = (int) startPos - (int) m_order + 1 ; currPos <= (int) startPos ; currPos++) {
    if (currPos >= 0)
      m_phraseVec[index++] = &hypo.GetWord(currPos);
    else {
      m_phraseVec[index++] = &m_bos;
    }
  }

  SCORE lmScore = GetValueCache(m_phraseVec);

  // main loop
  size_t endPos = std::min(startPos + m_order - 2
                           , currEndPos);
  for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
    // shift all args down 1 place
    for (size_t i = 0 ; i < m_order - 1 ; i++)
      m_phraseVec[i] = m_phraseVec[i + 1];

    // add last factor
    m_phraseVec.back() = &hypo.GetWord(currPos);

    lmScore	+= GetValueCache(m_phraseVec);
  }

  // end of sentence
  if (hypo.GetCoverage().IsComplete()) {
    const size_t size = hypo.GetSize();
    m_phraseVec.back() = &m_eos;

    for (size_t i = 0 ; i < m_order - 1 ; i ++) {
      int currPos = (int)(size - m_order + i + 1);
      if (currPos < 0)
        m_phraseVec[i] = &m_bos;
      else
        m_phraseVec[i] = &hypo.GetWord((size_t)currPos);
    }
    lmScore += GetValueCache(m_phraseVec);
  } else {
    if (endPos < currEndPos) {
      //need to get the LM state (otherwise the last LM state is fine)
      for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
        for (size_t i = 0 ; i < m_order - 1 ; i++)
          m_phraseVec[i] = m_phraseVec[i + 1];
        m_phraseVec.back() = &hypo.GetWord(currPos);
      }
    }
  }

  size_t state = GetLastState();
  return state;
}
/**
 * Add sparse global-lexical features for each word of the current target
 * phrase of `cur_hypo`, paired with words of the whole input sentence.
 *
 * Feature families, selected by member flags:
 *  - m_biasFeature:   "glm_<tgt>~**BIAS**" per target word
 *  - m_sourceContext: "glm_<tgt>~<s>,<src>" for source position 0, and
 *                     "glm_<tgt>~<src>,<ctx>" for each source word right of <src>
 *  - m_biphrase:      trigger-pair features emitted through AddFeature
 *  - m_bitrigger:     discontinuous triggers on both sides via AddFeature
 *  - otherwise:       "glm_<tgt>~<src>"
 *
 * Unless m_unrestricted, only words found in m_vocabSource / m_vocabTarget
 * produce features. With m_ignorePunctuation, words whose first character is
 * in m_punctuationHash are skipped. Only factor 0 is used.
 */
void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const
{
  const Sentence& input = *(m_local->input);
  const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();

  for(size_t targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) {
    StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0); // TODO: change for other factors

    if (m_ignorePunctuation) {
      // check if first char is punctuation
      char firstChar = targetString[0];
      CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
    }

    if (m_biasFeature) {
      stringstream feature;
      feature << "glm_";
      feature << targetString;
      feature << "~";
      feature << "**BIAS**";
      accumulator->SparsePlusEquals(feature.str(), 1);
    }

    boost::unordered_set<uint64_t> alreadyScored;
    for(size_t sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) {
      const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0);
      // TODO: change for other factors

      if (m_ignorePunctuation) {
        // check if first char is punctuation
        char firstChar = sourceString[0];
        CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
        if(charIterator != m_punctuationHash.end())
          continue;
      }
      const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size());

      if ( alreadyScored.find(sourceHash) == alreadyScored.end()) {
        // Only consulted when !m_unrestricted; initialised so they never
        // hold indeterminate values.
        bool sourceExists = false;
        bool targetExists = false;
        if (!m_unrestricted) {
          sourceExists = FindStringPiece(m_vocabSource, sourceString ) != m_vocabSource.end();
          targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end();
        }

        // no feature if vocab is in use and both words are not in restricted vocabularies
        if (m_unrestricted || (sourceExists && targetExists)) {
          if (m_sourceContext) {
            if (sourceIndex == 0) {
              // add <s> trigger feature for source
              stringstream feature;
              feature << "glm_";
              feature << targetString;
              feature << "~";
              feature << "<s>,";
              feature << sourceString;
              accumulator->SparsePlusEquals(feature.str(), 1);
              alreadyScored.insert(sourceHash);
            }

            // add source words to the right of current source word as context
            for(size_t contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) {
              StringPiece contextString = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
              bool contextExists = false;
              if (!m_unrestricted)
                contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end();

              if (m_unrestricted || contextExists) {
                stringstream feature;
                feature << "glm_";
                feature << targetString;
                feature << "~";
                feature << sourceString;
                feature << ",";
                feature << contextString;
                accumulator->SparsePlusEquals(feature.str(), 1);
                alreadyScored.insert(sourceHash);
              }
            }
          } else if (m_biphrase) {
            // --> look backwards for constructing context
            int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;

            // 1) source-target pair, trigger source word (can be discont.) and adjacent target word (bigram)
            StringPiece targetContext;
            if (globalTargetIndex > 0)
              targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0); // TODO: change for other factors
            else
              targetContext = "<s>";

            if (sourceIndex == 0) {
              StringPiece sourceTrigger = "<s>";
              AddFeature(accumulator, sourceTrigger, sourceString,
                         targetContext, targetString);
            } else
              for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
                StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
                bool sourceTriggerExists = false;
                if (!m_unrestricted)
                  sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

                if (m_unrestricted || sourceTriggerExists)
                  AddFeature(accumulator, sourceTrigger, sourceString,
                             targetContext, targetString);
              }

            // 2) source-target pair, adjacent source word (bigram) and trigger target word (can be discont.)
            StringPiece sourceContext;
            // sourceIndex is unsigned: "sourceIndex-1 >= 0" is always true and
            // wraps around at 0, so compare against 0 directly.
            if (sourceIndex > 0)
              sourceContext = input.GetWord(sourceIndex-1).GetString(0); // TODO: change for other factors
            else
              sourceContext = "<s>";

            if (globalTargetIndex == 0) {
              string targetTrigger = "<s>";
              AddFeature(accumulator, sourceContext, sourceString,
                         targetTrigger, targetString);
            } else
              for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
                StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
                bool targetTriggerExists = false;
                if (!m_unrestricted)
                  targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();

                if (m_unrestricted || targetTriggerExists)
                  AddFeature(accumulator, sourceContext, sourceString,
                             targetTrigger, targetString);
              }
          } else if (m_bitrigger) {
            // allow additional discont. triggers on both sides
            int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex;

            if (sourceIndex == 0) {
              StringPiece sourceTrigger = "<s>";
              bool sourceTriggerExists = true;

              if (globalTargetIndex == 0) {
                string targetTrigger = "<s>";
                bool targetTriggerExists = true;

                if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                  AddFeature(accumulator, sourceTrigger, sourceString,
                             targetTrigger, targetString);
              } else {
                // iterate backwards over target
                for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
                  StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
                  bool targetTriggerExists = false;
                  if (!m_unrestricted)
                    targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();

                  if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                    AddFeature(accumulator, sourceTrigger, sourceString,
                               targetTrigger, targetString);
                }
              }
            }
            // iterate over both source and target
            else {
              // iterate backwards over source
              for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) {
                StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); // TODO: change for other factors
                bool sourceTriggerExists = false;
                if (!m_unrestricted)
                  sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

                if (globalTargetIndex == 0) {
                  string targetTrigger = "<s>";
                  bool targetTriggerExists = true;

                  if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                    AddFeature(accumulator, sourceTrigger, sourceString,
                               targetTrigger, targetString);
                } else {
                  // iterate backwards over target
                  for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) {
                    StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); // TODO: change for other factors
                    bool targetTriggerExists = false;
                    if (!m_unrestricted)
                      targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end();

                    if (m_unrestricted || (sourceTriggerExists && targetTriggerExists))
                      AddFeature(accumulator, sourceTrigger, sourceString,
                                 targetTrigger, targetString);
                  }
                }
              }
            }
          } else {
            stringstream feature;
            feature << "glm_";
            feature << targetString;
            feature << "~";
            feature << sourceString;
            accumulator->SparsePlusEquals(feature.str(), 1);
            alreadyScored.insert(sourceHash);

          }
        }
      }
    }
  }
}
void WordTranslationFeature::EvaluateWhenApplied
(const Hypothesis& hypo,
 ScoreComponentCollection* accumulator) const
{
  const Sentence& input = static_cast<const Sentence&>(hypo.GetInput());
  const TranslationOption& transOpt = hypo.GetTranslationOption();
  const TargetPhrase& targetPhrase = hypo.GetCurrTargetPhrase();
  const AlignmentInfo &alignment = targetPhrase.GetAlignTerm();

  // process aligned words
  for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {
    const Phrase& sourcePhrase = transOpt.GetInputPath().GetPhrase();
    int sourceIndex = alignmentPoint->first;
    int targetIndex = alignmentPoint->second;
    Word ws = sourcePhrase.GetWord(sourceIndex);
    if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue;
    Word wt = targetPhrase.GetWord(targetIndex);
    if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue;
    StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString();
    StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString();
    if (m_ignorePunctuation) {
      // check if source or target are punctuation
      char firstChar = sourceWord[0];
      CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
      firstChar = targetWord[0];
      charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
    }

    if (!m_unrestricted) {
      if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end())
        sourceWord = "OTHER";
      if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end())
        targetWord = "OTHER";
    }

    if (m_simple) {
      // construct feature name
      stringstream featureName;
      featureName << m_description << "_";
      featureName << sourceWord;
      featureName << "~";
      featureName << targetWord;
      accumulator->SparsePlusEquals(featureName.str(), 1);
    }
    if (m_domainTrigger && !m_sourceContext) {
      const bool use_topicid = input.GetUseTopicId();
      const bool use_topicid_prob = input.GetUseTopicIdAndProb();
      if (use_topicid || use_topicid_prob) {
        if(use_topicid) {
          // use topicid as trigger
          const long topicid = input.GetTopicId();
          stringstream feature;
          feature << m_description << "_";
          if (topicid == -1)
            feature << "unk";
          else
            feature << topicid;

          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          accumulator->SparsePlusEquals(feature.str(), 1);
        } else {
          // use topic probabilities
          const vector<string> &topicid_prob = *(input.GetTopicIdAndProb());
          if (atol(topicid_prob[0].c_str()) == -1) {
            stringstream feature;
            feature << m_description << "_unk_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            accumulator->SparsePlusEquals(feature.str(), 1);
          } else {
            for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
              stringstream feature;
              feature << m_description << "_";
              feature << topicid_prob[i];
              feature << "_";
              feature << sourceWord;
              feature << "~";
              feature << targetWord;
              accumulator->SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
            }
          }
        }
      } else {
        // range over domain trigger words (keywords)
        const long docid = input.GetDocumentId();
        for (boost::unordered_set<std::string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
          string sourceTrigger = *p;
          stringstream feature;
          feature << m_description << "_";
          feature << sourceTrigger;
          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          accumulator->SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_sourceContext) {
      size_t globalSourceIndex = hypo.GetTranslationOption().GetStartPos() + sourceIndex;
      if (!m_domainTrigger && globalSourceIndex == 0) {
        // add <s> trigger feature for source
        stringstream feature;
        feature << m_description << "_";
        feature << "<s>,";
        feature << sourceWord;
        feature << "~";
        feature << targetWord;
        accumulator->SparsePlusEquals(feature.str(), 1);
      }

      // range over source words to get context
      for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
        if (contextIndex == globalSourceIndex) continue;
        StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
        if (m_ignorePunctuation) {
          // check if trigger is punctuation
          char firstChar = sourceTrigger[0];
          CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
          if(charIterator != m_punctuationHash.end())
            continue;
        }

        const long docid = input.GetDocumentId();
        bool sourceTriggerExists = false;
        if (m_domainTrigger)
          sourceTriggerExists = FindStringPiece(m_vocabDomain[docid], sourceTrigger ) != m_vocabDomain[docid].end();
        else if (!m_unrestricted)
          sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

        if (m_domainTrigger) {
          if (sourceTriggerExists) {
            stringstream feature;
            feature << m_description << "_";
            feature << sourceTrigger;
            feature << "_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            accumulator->SparsePlusEquals(feature.str(), 1);
          }
        } else if (m_unrestricted || sourceTriggerExists) {
          stringstream feature;
          feature << m_description << "_";
          if (contextIndex < globalSourceIndex) {
            feature << sourceTrigger;
            feature << ",";
            feature << sourceWord;
          } else {
            feature << sourceWord;
            feature << ",";
            feature << sourceTrigger;
          }
          feature << "~";
          feature << targetWord;
          accumulator->SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_targetContext) {
      throw runtime_error("Can't use target words outside current translation option in a stateless feature");
      /*
      size_t globalTargetIndex = cur_hypo.GetCurrTargetWordsRange().GetStartPos() + targetIndex;
      if (globalTargetIndex == 0) {
      	// add <s> trigger feature for source
      	stringstream feature;
      	feature << "wt_";
      	feature << sourceWord;
      	feature << "~";
      	feature << "<s>,";
      	feature << targetWord;
      	accumulator->SparsePlusEquals(feature.str(), 1);
      }

      // range over target words (up to current position) to get context
      for(size_t contextIndex = 0; contextIndex < globalTargetIndex; contextIndex++ ) {
      	string targetTrigger = cur_hypo.GetWord(contextIndex).GetFactor(m_factorTypeTarget)->GetString();
      	if (m_ignorePunctuation) {
      		// check if trigger is punctuation
      		char firstChar = targetTrigger.at(0);
      		CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      		if(charIterator != m_punctuationHash.end())
      			continue;
      	}

      	bool targetTriggerExists = false;
      	if (!m_unrestricted)
      		targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();

      	if (m_unrestricted || targetTriggerExists) {
      		stringstream feature;
      		feature << "wt_";
      		feature << sourceWord;
      		feature << "~";
      		feature << targetTrigger;
      		feature << ",";
      		feature << targetWord;
      		accumulator->SparsePlusEquals(feature.str(), 1);
      	}
      }*/
    }
  }
}
예제 #18
0
void PhrasePairFeature::EvaluateWhenApplied(
  const Hypothesis& hypo,
  ScoreComponentCollection* accumulator) const
{
  const TargetPhrase& target = hypo.GetCurrTargetPhrase();
  const Phrase& source = hypo.GetTranslationOption().GetInputPath().GetPhrase();
  if (m_simple) {
    ostringstream namestr;
    namestr << "pp_";
    namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      namestr << ",";
      namestr << sourceFactor->GetString();
    }
    namestr << "~";
    namestr << target.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < target.GetSize(); ++i) {
      const Factor* targetFactor = target.GetWord(i).GetFactor(m_targetFactorId);
      namestr << ",";
      namestr << targetFactor->GetString();
    }

    accumulator->SparsePlusEquals(namestr.str(),1);
  }
  if (m_domainTrigger) {
    const Sentence& input = static_cast<const Sentence&>(hypo.GetInput());
    const bool use_topicid = input.GetUseTopicId();
    const bool use_topicid_prob = input.GetUseTopicIdAndProb();

    // compute pair
    ostringstream pair;
    pair << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      pair << ",";
      pair << sourceFactor->GetString();
    }
    pair << "~";
    pair << target.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < target.GetSize(); ++i) {
      const Factor* targetFactor = target.GetWord(i).GetFactor(m_targetFactorId);
      pair << ",";
      pair << targetFactor->GetString();
    }

    if (use_topicid || use_topicid_prob) {
      if(use_topicid) {
        // use topicid as trigger
        const long topicid = input.GetTopicId();
        stringstream feature;
        feature << "pp_";
        if (topicid == -1)
          feature << "unk";
        else
          feature << topicid;

        feature << "_";
        feature << pair.str();
        accumulator->SparsePlusEquals(feature.str(), 1);
      } else {
        // use topic probabilities
        const vector<string> &topicid_prob = *(input.GetTopicIdAndProb());
        if (atol(topicid_prob[0].c_str()) == -1) {
          stringstream feature;
          feature << "pp_unk_";
          feature << pair.str();
          accumulator->SparsePlusEquals(feature.str(), 1);
        } else {
          for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
            stringstream feature;
            feature << "pp_";
            feature << topicid_prob[i];
            feature << "_";
            feature << pair.str();
            accumulator->SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
          }
        }
      }
    } else {
      // range over domain trigger words
      const long docid = input.GetDocumentId();
      for (set<string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
        string sourceTrigger = *p;
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "_";
        namestr << pair.str();
        accumulator->SparsePlusEquals(namestr.str(),1);
      }
    }
  }
  if (m_sourceContext) {
    const Sentence& input = static_cast<const Sentence&>(hypo.GetInput());

    // range over source words to get context
    for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
      StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_sourceFactorId)->GetString();
      if (m_ignorePunctuation) {
        // check if trigger is punctuation
        char firstChar = sourceTrigger[0];
        CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
        if(charIterator != m_punctuationHash.end())
          continue;
      }

      bool sourceTriggerExists = false;
      if (!m_unrestricted)
        sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

      if (m_unrestricted || sourceTriggerExists) {
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "~";
        namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
        for (size_t i = 1; i < source.GetSize(); ++i) {
          const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
          namestr << ",";
          namestr << sourceFactor->GetString();
        }
        namestr << "~";
        namestr << target.GetWord(0).GetFactor(m_targetFactorId)->GetString();
        for (size_t i = 1; i < target.GetSize(); ++i) {
          const Factor* targetFactor = target.GetWord(i).GetFactor(m_targetFactorId);
          namestr << ",";
          namestr << targetFactor->GetString();
        }

        accumulator->SparsePlusEquals(namestr.str(),1);
      }
    }
  }
}
예제 #19
0
// Build the state from a hypothesis: m_outputPhrase is filled through the
// out-parameter of Hypothesis::GetOutputPhrase.
ConstrainedDecodingState::ConstrainedDecodingState(const Hypothesis &hypo)
{
  hypo.GetOutputPhrase(this->m_outputPhrase);
}
예제 #20
0
/*
 * Choose the next observation target for observer `o`.
 *
 * The chosen hypothesis is removed from mPendingRequests.
 *
 * - No pending requests and `o` is not GLOBAL: the observer is set idle.
 * - At least as many observers as observable groups and `o` already has
 *   target units: `o` is re-targeted to its current hypothesis if that
 *   hypothesis is pending again; otherwise it is left as is.
 * - Otherwise the strategy depends on the observer type:
 *     SEQUENTIAL  - lowest (distance / map size + hypothesis priority).
 *     ROUND_ROBIN - the front of the pending queue.
 *     THREAT      - smallest positive "leeway" (estimated attack time of the
 *                   group minus the observer's travel time), with fall-backs.
 *     GLOBAL, FOLLOWING - currently no-ops (their logic is commented out).
 */
void Observations::setTarget(UnitObserver* o) {
	UnitObserverType type = o->getType();

	if (mPendingRequests.empty() && type != GLOBAL) {
		//nothing to do
		o->setIdle();
		//EventLog::GetInstance().logEvent("IDLE "+o->getName());
		return;
	}

	LOG_INFO("Setting new target for observer "+o->getName());

	//if we have the same number of observers as hypotheses, then don't move them from their target group after it is set
	if (mNumObservers >= mNumObservableGroups && o->hasTargetUnits()) {
		//see if we're requesting this hypothesis again
		std::deque<Hypothesis*>::iterator hit = std::find(mPendingRequests.begin(), mPendingRequests.end(), o->getHypothesis());
		if (hit != mPendingRequests.end()) {
			o->setTarget(o->getHypothesis(), o->getTargetGroup(), (float)o->getAltitude(), o->getTargetDuration());
			mPendingRequests.erase(hit);
		}
		//EventLog::GetInstance().logEvent("FULL "+o->getName());
		return;
	}

	switch (type) {
		case GLOBAL: {
			//move to the centre of the observed groups
			/*std::vector<dtCore::RefPtr<UnitGroup> > groups = theApp->getTargets()->groupUnits(5000, RED);
			if (groups.size() > 0) {
				UnitGroup* g = groups[0].get();
				osg::Vec3 centre = g->getCentre();
				float radius = g->getRadius();
				centre.z() += radius;
				o->setTarget(centre, 1);
			}*/
		}
		break;
		case SEQUENTIAL: {
			//nearest neighbour (+ priorities) from start position, and wait 5 seconds
			float dMax = theApp->getTerrain()->GetHorizontalSize() * sqrt(2.0); //size of map
			std::deque<Hypothesis*>::iterator it;
			float dMin = std::numeric_limits<float>::max();
			dtCore::RefPtr<UnitGroup> group;
			Hypothesis* hypothesis = NULL;
			for (it = mPendingRequests.begin(); it != mPendingRequests.end(); ++it) {
				dtCore::RefPtr<UnitGroup> g = (*it)->getObservationGroup();
				float d = (g->getCentre() - o->getPosition()).length();
				//normalise
				d /= dMax;
				//add in hypothesis priorities
				float p = d + (*it)->getPriority();
				LOG_INFO("Hypothesis group "+boost::lexical_cast<std::string>((*it)->getName())+": Priority "+boost::lexical_cast<std::string>(p)+" ("+boost::lexical_cast<std::string>(d)+" + "+boost::lexical_cast<std::string>((*it)->getPriority())+")");
				//std::cout << (*it)->getName() << ":d" << d <<",t"<<(*it)->getPriority()<<",p"<<p<<std::endl;
				if (p < dMin) {
					dMin = p;
					group = g;
					hypothesis = *it;
				}
			}
			//no finite priority was found (e.g. degenerate map size gives NaN/inf):
			//erasing std::find(..., NULL) would erase end(), which is undefined
			if (hypothesis == NULL) {
				o->setIdle();
				break;
			}
			//dtCore::RefPtr<UnitGroup> u = nearestNeighbour(o->getPosition(), mPendingRequests);
			//o->setTarget(u.get(), 500, 2);
			//TODO set duration and altitude properly
			o->setTarget(hypothesis, group.get(), 500, theApp->getObservationDuration());
			mPendingRequests.erase(std::find(mPendingRequests.begin(), mPendingRequests.end(), hypothesis));
		}
		break;
		case ROUND_ROBIN: {
			if (!mPendingRequests.empty()) {
				//EventLog::GetInstance().logEvent("OBST "+mPendingRequests.front()->getName());
				o->setTarget(mPendingRequests.front(), mPendingRequests.front()->getObservationGroup().get(), 500, theApp->getObservationDuration());
				mPendingRequests.pop_front();
			}
		}
		break;
		case THREAT: {
			//calculate resultant threat of observing each target

			//decide what target each observee is going to...
			//conf > 0.5. conf decays (or variance increases) without observations.
			//target certainty... take closest if conf < 0.5 or if variance = global variance

			//err, lets just take closest for now...
			osg::Vec3 oPos = theApp->mapPointToTerrain(o->getPosition());
			std::deque<Hypothesis*>::iterator it;
			dtCore::RefPtr<UnitGroup> group;
			Hypothesis* hypothesis = NULL;
			UnitGroup* predictedGroup = NULL;
			UnitGroup* closestGroup = NULL;
			float confidence = 0;
			float minLeeway = std::numeric_limits<float>::max();
			for (it = mPendingRequests.begin(); it != mPendingRequests.end(); ++it) {
				dtCore::RefPtr<UnitGroup> g = (*it)->getObservationGroup();
				osg::Vec3 gPos = g->getCentreCurrent();
				//float gRadius = g->getRadius();
				std::vector<dtCore::RefPtr<UnitGroup> > targetGroups = (*it)->getTargetGroups();
				dtCore::RefPtr<UnitGroup> targetGroup;
				float dMin = std::numeric_limits<float>::max();
				//find closest target group for this observation group
				for (std::vector<dtCore::RefPtr<UnitGroup> >::iterator it2 = targetGroups.begin(); it2 != targetGroups.end(); ++it2) {
					float d = (gPos - (*it2)->getCentreCurrent()).length();
					//TODO fix: subtract radii from centre to centre distance
					//d = d - gRadius - (*it2)->getRadius();
					if (d < 0) d = 0;
					if (d < dMin) {
						dMin = d;
						targetGroup = *it2;
					}
				}
				//calc threat (time = dMin / speed)
				//if targetgroup is the same as the highest confidence (>0.5) group then assume running away
				UnitGroup* predictedTarget = (*it)->getTargetGroup();
				float conf = (*it)->getConfidence();
				if (predictedTarget != NULL && predictedTarget == targetGroup.get() && conf > 0.5) {
					//dMin = v2*ds/(v2-v1)
					//(equal speeds give inf, a faster target gives a negative value; both are
					//rejected by the leeway test below)
					dMin = (g->getMaxSpeed() * dMin) / (g->getMaxSpeed() - targetGroup->getMaxSpeed());
				}
				float attackTime = dMin / g->getMaxSpeed();

				//find time for observer to travel to observation group
				float travelTime = (gPos - oPos).length() / o->getMaxSpeed();

				//calc leeway
				float leeway = attackTime - travelTime;

				std::ostringstream oss;
				oss << "OBSL " << o->getName() << " " << (*it)->getName() << " " << attackTime << " " << travelTime << " " << leeway;
				EventLog::GetInstance().logEvent(oss.str());

				//take lowest (non-negative) leeway
				if (leeway > 0 && leeway < minLeeway) {
					minLeeway = leeway;
					group = g;
					hypothesis = *it;
					predictedGroup = predictedTarget;
					closestGroup = targetGroup.get();
					confidence = conf;
				}
			}
			//if hypothesis has high confidence, but not at min leeway target, then see if more important target to view
			//i.e., make better use of our time as min leeway isn't very likely
			if (hypothesis && confidence > 0.5 && closestGroup != predictedGroup) {
				float minTime = minLeeway;
				for (it = mPendingRequests.begin(); it != mPendingRequests.end(); ++it) {
					float conf = (*it)->getConfidence();
					if (conf <= 0) {
						dtCore::RefPtr<UnitGroup> g = (*it)->getObservationGroup();
						osg::Vec3 gPos = g->getCentreCurrent();
						float travelTime = (gPos - oPos).length() / o->getMaxSpeed();
						//time to get there, observe, and get back
						float totalTime = (travelTime * 2) + (*it)->getSimTime();
						if (totalTime < minTime) {
							//track the quickest round trip (previously minTime was never
							//updated, so the last fitting request won instead)
							minTime = totalTime;
							hypothesis = *it;
							group = g;
						}
					}
				}
			}
			//we didn't find any group to look at (probably they were all predicted (inf))
			if (hypothesis == NULL) {
				//find the hypothesis that has longest time since last obs.
				//The first request always seeds the search, so a hypothesis is found even
				//when every time difference is negative (the queue is non-empty here).
				double maxTime = 0;
				for (it = mPendingRequests.begin(); it != mPendingRequests.end(); ++it) {
					double t = (*it)->getTimeDiff();
					if (hypothesis == NULL || t >= maxTime) {
						maxTime = t;
						hypothesis = *it;
					}
				}
				group = hypothesis->getObservationGroup();
			}

			o->setTarget(hypothesis, group.get(), 500, theApp->getObservationDuration());
			mPendingRequests.erase(std::find(mPendingRequests.begin(), mPendingRequests.end(), hypothesis));

			LOG_DEBUG("Setting target to "+hypothesis->getName());
			//not sure if one target will hog the observer...
		}
		break;
		/*case SCANNING: {
			//find nearest neighbour
			dtCore::RefPtr<UnitGroup> u = nearestNeighbour(o->getPosition(), mPendingRequests);
			//find point where it comes into view
			osg::Vec3 obsPos = getObservationPoint(o->getGroundPosition(), o->getViewRadius(), u->getCentre(), u->getRadius());

			float totalTime = 2;
			float timeLeft = (totalTime / 2) - 0.5; //allowing time to turn around and return
			while (timeLeft > 0) {
				//find next neighbour
				dtCore::RefPtr<UnitGroup> u2 = nearestNeighbour(obsPos, mPendingRequests);
				osg::Vec3 obsPos2 = getObservationPoint(obsPos, o->getViewRadius(), u2->getCentre(), u2->getRadius());
				float nextDist = (obsPos - obsPos2).length();

				//see if we can get to next nearest neighbour and back within 2 seconds
				float travelTime = getTimeToTravel(nextDist);
				if (travelTime < timeLeft) {
					//add
					o->addTarget(u.get(), 500, 2);
					mPendingRequests.erase(std::find(mPendingRequests.begin(), mPendingRequests.end(), u));
				}
				timeLeft -= travelTime;
			}
			//add the targets in reverse order to get back...

			o->setTarget(u.get(), 500, 2);
			mPendingRequests.erase(std::find(mPendingRequests.begin(), mPendingRequests.end(), u));
		}
		break;*/
		/*case GROUP: {
			//get group
			dtCore::RefPtr<UnitGroup> u = nearestNeighbour(o->getPosition(), mPendingRequests);
			UnitGroup* group = theApp->getTargets()->groupObservedUnits(RED,u.get(),250);
			//set target position
			o->setTarget(group, 300, 2);
			//remove observed units
			std::vector<Unit*> units = group->getUnits();
			for (unsigned int i=0; i<units.size(); i++) {
				for (std::vector<dtCore::RefPtr<UnitGroup> >::iterator it = mPendingRequests.begin(); it != mPendingRequests.end(); it++) {
					if ((*it)->isInGroup(units[i])) {
						mPendingRequests.erase(it);
						break;
					}
				}
			}
		}
		break;*/
		case FOLLOWING: {
			//if leader, find nearest neighbour
			/*dtCore::RefPtr<UnitGroup> u = nearestNeighbour(o->getPosition(), mPendingRequests);
			o->setTarget(itMax->get(), 500, 0);
			mPendingRequests.erase(std::find(mPendingRequests.begin(), mPendingRequests.end(), u));
			*///if follower get to leaders position 2 seconds later
		}
		break;
	}

}
예제 #21
0
/***
 * print surface factor only for the given phrase
 */
void OutputSurface(std::ostream &out, const Hypothesis &edge, const std::vector<FactorType> &outputFactorOrder,
                   char reportSegmentation, bool reportAllFactors)
{
  CHECK(outputFactorOrder.size() > 0);
  const TargetPhrase& phrase = edge.GetCurrTargetPhrase();
  bool markUnknown = StaticData::Instance().GetMarkUnknown();
  if (reportAllFactors == true) {
    out << phrase;
  } else {
    FactorType placeholderFactor = StaticData::Instance().GetPlaceholderFactor();

    std::map<size_t, const Factor*> placeholders;
    if (placeholderFactor != NOT_FOUND) {
      // creates map of target position -> factor for placeholders
      placeholders = GetPlaceholders(edge, placeholderFactor);
    }

    size_t size = phrase.GetSize();
    for (size_t pos = 0 ; pos < size ; pos++) {
      const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[0]);

      if (placeholders.size()) {
        // do placeholders
        std::map<size_t, const Factor*>::const_iterator iter = placeholders.find(pos);
        if (iter != placeholders.end()) {
          factor = iter->second;
        }
      }

      CHECK(factor);

      //preface surface form with UNK if marking unknowns
      const Word &word = phrase.GetWord(pos);
      if(markUnknown && word.IsOOV()) {
        out << "UNK" << *factor;
      } else {
        out << *factor;
      }

      for (size_t i = 1 ; i < outputFactorOrder.size() ; i++) {
        const Factor *factor = phrase.GetFactor(pos, outputFactorOrder[i]);
        CHECK(factor);

        out << "|" << *factor;
      }
      out << " ";
    }
  }

  // trace ("report segmentation") option "-t" / "-tt"
  if (reportSegmentation > 0 && phrase.GetSize() > 0) {
    const WordsRange &sourceRange = edge.GetCurrSourceWordsRange();
    const int sourceStart = sourceRange.GetStartPos();
    const int sourceEnd = sourceRange.GetEndPos();
    out << "|" << sourceStart << "-" << sourceEnd;    // enriched "-tt"
    if (reportSegmentation == 2) {
      out << ",wa=";
      const AlignmentInfo &ai = edge.GetCurrTargetPhrase().GetAlignTerm();
      OutputAlignment(out, ai, 0, 0);
      out << ",total=";
      out << edge.GetScore() - edge.GetPrevHypo()->GetScore();
      out << ",";
      ScoreComponentCollection scoreBreakdown(edge.GetScoreBreakdown());
      scoreBreakdown.MinusEquals(edge.GetPrevHypo()->GetScoreBreakdown());
      OutputAllFeatureScores(scoreBreakdown, out);
    }
    out << "| ";
  }
}
예제 #22
0
파일: Decode.cpp 프로젝트: hznlp/pbmt
/*
 * Stack-based beam search over the source sentence.
 *
 * Heap k holds hypotheses that cover k source words. Each heap is pruned to
 * `beamsize` and every hypothesis in it is extended by every span [start, stop]
 * that meets three conditions:
 *   - it contains only uncovered words;
 *   - it either starts at the first uncovered word or ends before
 *     firstUncovered + distLimit;
 *   - it has rules in `pt`.
 * At most `tlimit` rules are tried per span. An extension covering
 * stop - start + 1 words goes into heap k + stop - start + 1. The final heap is
 * pruned at the end. `debug` prints hypotheses to stdout.
 */
void 
SearchSpace::
beamSearch(PhraseTable& pt, Features& weight, int beamsize, int distLimit, int tlimit, bool debug)
{
	// heap 0 starts with the empty hypothesis
	Hypothesis seed;
	_hypoHeaps[0].push_back(seed);
	Hypothesis& root=_hypoHeaps[0][0];
	root.coveredWords.init(_sentence.size(),false);
	root.currentScore=0;
	root.estimatedScore=0;
	root.features.init();
	root.lastCoveredWord=-1;
	root.represent="";
	root.translation.clear();
	root.trace.p_prev=NULL;
	root.trace.p_rule=NULL;

	for(size_t covered=0;covered<_sentence.size();++covered)
	{
		HypothesisHeap& heap=_hypoHeaps[covered];
		heap.sortAndPrune(beamsize);
		for(int h=0;h<(int)heap.size();++h)
		{
			Hypothesis& parent=heap[h];
			if(debug)
			{
				cout<<"CurHypo ::"<<endl;
				parent.print(cout);
			}
			BITVECTOR coverage=parent.coveredWords;
			int firstGap=coverage.firstFalse();
			for(int start=0;start<(int)coverage.size();++start)
			{
				if(coverage[start]==true)
					continue;
				// grow the span while it stays uncovered and within the distortion limit
				int stop=start;
				while(stop<(int)coverage.size()&&
				      coverage[stop]==false&&
				      (stop<distLimit+firstGap||start==firstGap))
				{
					pair<int,int> span=make_pair(start,stop);
					string srcPhrase=_auxSpace.queryPhrase(span);
					vector<PhraseRuleEntry*>* candidates=pt.queryRulesVec(srcPhrase);
					if(candidates!=NULL)
					{
						int targetHeap=covered+stop-start+1;
						BITVECTOR extended;
						combine(coverage,span,extended);
						double futureScore=_auxSpace.queryFutureCost(extended);

						vector<PhraseRuleEntry*>& rules=*candidates;
						for(size_t r=0;r<rules.size()&&(int)r<tlimit;++r)
						{
							PhraseRuleEntry& rule=*rules[r];
							Hypothesis child;
							child.genFromChild(parent,make_pair(start,stop),rule,weight,futureScore);
							if(debug)
							{
								cout<<"newHypo :: "<<endl;
								child.print(cout);
							}
							_hypoHeaps[targetHeap].addHypothesis(child);
						}
					}
					++stop;
				}
			}
		}
	}

	HypothesisHeap& finalHeap=_hypoHeaps.back();
	finalHeap.sortAndPrune(beamsize);
	if(debug)
	{
		cout<<"final hHeap"<<endl;
		for(size_t i=0;i<finalHeap.size();++i)
		{
			finalHeap[i].print(cout);
		}
	}
}
예제 #23
0
bool TranslationOption::Overlap(const Hypothesis &hypothesis) const
{
	const WordsBitmap &bitmap = hypothesis.GetWordsBitmap();
	return bitmap.Overlap(GetSourceWordsRange());
}
예제 #24
0
void KENLM<Model>::EvaluateWhenApplied(const ManagerBase &mgr,
                                       const Hypothesis &hypo, const FFState &prevState, Scores &scores,
                                       FFState &state) const
{
  // Continue the LM from the previous hypothesis' state over the words this
  // hypothesis appends, and store the resulting LM state in `state`.
  // Only the first (order - 1) appended words are scored here; later words are
  // not. Presumably they are covered by phrase-internal scoring elsewhere
  // (TODO confirm).
  const lm::ngram::State &prevLmState =
    static_cast<const KenLMState&>(prevState).state;
  KenLMState &nextState = static_cast<KenLMState&>(state);

  // An empty target phrase keeps the context unchanged and adds no score.
  if (hypo.GetTargetPhrase().GetSize() == 0) {
    nextState.state = prevLmState;
    return;
  }

  const System &system = mgr.system;

  const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
  const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;  // exclusive
  const std::size_t contextEnd = std::min(end, begin + m_ngram->Order() - 1);

  // Ping-pong between the output state and a scratch state while scoring.
  typename Model::State scratch;
  typename Model::State *current = &nextState.state;
  typename Model::State *spare = &scratch;

  std::size_t pos = begin;
  float lmScore = m_ngram->Score(prevLmState, TranslateID(hypo.GetWord(pos)), *current);
  for (++pos; pos < contextEnd; ++pos) {
    lmScore += m_ngram->Score(*current, TranslateID(hypo.GetWord(pos)), *spare);
    std::swap(current, spare);
  }

  if (hypo.GetBitmap().IsComplete()) {
    // Sentence complete: add the end-of-sentence score given the history
    // that LastIDs collects (assumed to be the most recent target words).
    std::vector<lm::WordIndex> history(m_ngram->Order() - 1);
    const lm::WordIndex *historyEnd = LastIDs(hypo, &history.front());
    lmScore += m_ngram->FullScoreForgotState(&history.front(), historyEnd,
               m_ngram->GetVocabulary().EndSentence(), nextState.state).prob;
  } else if (contextEnd < end) {
    // Phrase longer than the scored window: rebuild the state from the history.
    std::vector<lm::WordIndex> history(m_ngram->Order() - 1);
    const lm::WordIndex *historyEnd = LastIDs(hypo, &history.front());
    m_ngram->GetState(&history.front(), historyEnd, nextState.state);
  } else if (current != &nextState.state) {
    // An odd number of swaps left the final state in `scratch`; copy it out.
    nextState.state = *current;
  }

  lmScore = TransformLMScore(lmScore);

  // The OOV feature is hard-wired off, so only the single LM score is added.
  bool OOVFeatureEnabled = false;
  if (!OOVFeatureEnabled) {
    scores.PlusEquals(system, *this, lmScore);
  } else {
    std::vector<float> scoresVec(2);
    scoresVec[0] = lmScore;
    scoresVec[1] = 0.0;
    scores.PlusEquals(system, *this, scoresVec);
  }
}