Пример #1
0
// Turns every XML option attached to `sentence` into a single-entry target
// phrase list and registers it on the input path selected by the option's
// startPos / phraseSize.
//
//   mgr        - gives access to the system object (feature functions etc.)
//   pool       - memory pool used for the target phrase and its container
//   sentence   - source of the XML options
//   inputPaths - provides the matrix from which the input path is looked up
//
// NOTE(review): this excerpt ends before the function's closing brace.
void UnknownWordPenalty::ProcessXML(
		const Manager &mgr,
		MemPool &pool,
		const Sentence &sentence,
		InputPaths &inputPaths) const
{
	const Vector<const InputType::XMLOption*> &xmlOptions = sentence.GetXMLOptions();
	BOOST_FOREACH(const InputType::XMLOption *xmlOption, xmlOptions) {
		// Build the target phrase from the translation string carried by the option.
		TargetPhraseImpl *target = TargetPhraseImpl::CreateFromString(pool, *this, mgr.system, xmlOption->GetTranslation());

    // Only a non-zero probability contributes a score for this feature.
    if (xmlOption->prob) {
      Scores &scores = target->GetScores();
      scores.PlusEquals(mgr.system, *this, Moses2::TransformScore(xmlOption->prob));
    }

    // The matrix is addressed by (start position, phrase size - 1).
    InputPath *path = inputPaths.GetMatrix().GetValue(xmlOption->startPos, xmlOption->phraseSize - 1);
    const SubPhrase<Moses2::Word> &source = path->subPhrase;

    mgr.system.featureFunctions.EvaluateInIsolation(pool, mgr.system, source, *target);

    // Placement-new the container inside the pool; it holds just this phrase.
    TargetPhrases *tps = new (pool.Allocate<TargetPhrases>()) TargetPhrases(pool, 1);

    tps->AddTargetPhrase(*target);
    mgr.system.featureFunctions.EvaluateAfterTablePruning(pool, *tps, source);

    path->AddTargetPhrases(*this, tps);
	}
bool PhraseDictionaryTransliteration::SatisfyBackoff(const InputPath &inputPath) const
{
  const Phrase &sourcePhrase = inputPath.GetPhrase();

  assert(m_container);
  const DecodeGraph *decodeGraph = m_container->GetContainer();
  size_t backoff = decodeGraph->GetBackoff();

  if (backoff == 0) {
	  // ie. don't backoff. Collect ALL translations
	  return true;
  }

  if (sourcePhrase.GetSize() > backoff) {
	  // source phrase too big
	  return false;
  }

  // lookup translation only if no other translations
  InputPath::TargetPhrases::const_iterator iter;
  for (iter = inputPath.GetTargetPhrases().begin(); iter != inputPath.GetTargetPhrases().end(); ++iter) {
  	const std::pair<const TargetPhraseCollection*, const void*> &temp = iter->second;
  	const TargetPhraseCollection *tpCollPrev = temp.first;

  	if (tpCollPrev && tpCollPrev->GetSize()) {
  		// already have translation from another pt. Don't create translations
  		return false;
  	}
  }

  return true;
}
/**
 * Attach the transliterations of inputPath's source phrase to inputPath.
 *
 * Results are cached under the hash of the source phrase. On a cache miss
 * the phrase is written to a scratch file, the external
 * prepare-transliteration-phrase-table.pl script is run on it, and the
 * target phrases built from the script's output directory are collected,
 * cached and set on inputPath.
 *
 * The scratch file and output directory are removed when the lookup ends,
 * including when the script fails or building the target phrases throws.
 *
 * @param inputPath path to transliterate; receives the target phrases
 * @throws (via UTIL_THROW_IF2) if the script exits with a non-zero status
 */
void PhraseDictionaryTransliteration::GetTargetPhraseCollection(InputPath &inputPath) const
{
    // Deletes the scratch input file and output directory on scope exit so
    // that they are cleaned up on the error paths as well.
    struct ScratchCleaner {
        const string &m_file;
        const string &m_dir;

        ScratchCleaner(const string &file, const string &dir)
            : m_file(file), m_dir(dir) {}

        ~ScratchCleaner() {
            remove(m_file.c_str());

            // best effort: a failed removal is deliberately not reported
            const string rmCmd = "rm -rf " + m_dir;
            const int rmRet = system(rmCmd.c_str());
            (void) rmRet;
        }
    };

    const Phrase &sourcePhrase = inputPath.GetPhrase();
    size_t hash = hash_value(sourcePhrase);

    CacheColl &cache = GetCache();

    std::map<size_t, std::pair<const TargetPhraseCollection*, clock_t> >::iterator cacheIter;
    cacheIter = cache.find(hash);

    if (cacheIter != cache.end()) {
        // already in cache
        const TargetPhraseCollection *tpColl = cacheIter->second.first;
        inputPath.SetTargetPhrases(*this, tpColl, NULL);
    }
    else {
        // TRANSLITERATE
        // NOTE(review): tmpnam() only returns a name, so another process can
        // create the path first; mkstemp()/mkdtemp() would close that race but
        // need headers that are not visible from this block.
        char *ptr = tmpnam(NULL);
        string inFile(ptr);
        ptr = tmpnam(NULL);
        string outDir(ptr);

        // from here on the scratch paths are removed however we leave the block
        ScratchCleaner cleaner(inFile, outDir);

        ofstream inStream(inFile.c_str());
        inStream << sourcePhrase.ToString() << endl;
        inStream.close();

        // NOTE(review): the paths are spliced into a shell command unquoted;
        // values containing spaces or shell metacharacters will break it.
        string cmd = m_scriptDir + "/Transliteration/prepare-transliteration-phrase-table.pl" +
                " --transliteration-model-dir " + m_filePath +
                " --moses-src-dir " + m_mosesDir +
                " --external-bin-dir " + m_externalDir +
                " --input-extension " + m_inputLang +
                " --output-extension " + m_outputLang +
                " --oov-file " + inFile +
                " --out-dir " + outDir;

        int ret = system(cmd.c_str());
        UTIL_THROW_IF2(ret != 0, "Transliteration script error");

        TargetPhraseCollection *tpColl = new TargetPhraseCollection();
        vector<TargetPhrase*> targetPhrases = CreateTargetPhrases(sourcePhrase, outDir);
        vector<TargetPhrase*>::const_iterator tpIter;
        for (tpIter = targetPhrases.begin(); tpIter != targetPhrases.end(); ++tpIter) {
            TargetPhrase *tp = *tpIter;
            tpColl->Add(tp);
        }

        std::pair<const TargetPhraseCollection*, clock_t> value(tpColl, clock());
        cache[hash] = value;

        inputPath.SetTargetPhrases(*this, tpColl, NULL);
    }
}
Пример #4
0
	// Fragment: the enclosing function header is not part of this excerpt.
	// For every input path that passes the backoff test, look up its target
	// phrases and attach them to the path.
	BOOST_FOREACH(InputPathBase *pathBase, inputPaths){
	  // unchecked downcast of the base pointer
	  InputPath *path = static_cast<InputPath*>(pathBase);

	  if (SatisfyBackoff(mgr, *path)) {
		  // NOTE(review): `phrase` is not used anywhere in the visible code.
		  const SubPhrase<Moses2::Word> &phrase = path->subPhrase;

		  TargetPhrases *tps = Lookup(mgr, mgr.GetPool(), *path);
		  path->AddTargetPhrases(*this, tps);
	  }
	}
// assumes that source-side syntax labels are stored in the target non-terminal field of the rules
void SourceGHKMTreeInputMatchFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  const Range& range = inputPath.GetWordsRange();
  size_t startPos = range.GetStartPos();
  size_t endPos = range.GetEndPos();
  const TreeInput& treeInput = static_cast<const TreeInput&>(input);
  const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(startPos,endPos);
  const Word& lhsLabel = targetPhrase.GetTargetLHS();

  const StaticData& staticData = StaticData::Instance();
  const Word& outputDefaultNonTerminal = staticData.GetOutputDefaultNonTerminal();

  std::vector<float> newScores(m_numScoreComponents,0.0); // m_numScoreComponents == 2 // first fires for matches, second for mismatches

  if ( (treeInputLabels.find(lhsLabel) != treeInputLabels.end()) && (lhsLabel != outputDefaultNonTerminal) ) {
    // match
    newScores[0] = 1.0;
  } else {
    // mismatch
    newScores[1] = 1.0;
  }

  scoreBreakdown.PlusEquals(this, newScores);
}
/**
 * Look up the target phrases for inputPath by extending the on-disk node of
 * the previous (one word shorter) path with the last word of the phrase.
 *
 * Nothing is set on inputPath when the backoff test fails or when the
 * previous path has no node for this table. Otherwise inputPath receives
 * either the collection and node found, or NULL/NULL when the last word is
 * unknown to the table or the previous node has no matching child.
 */
void PhraseDictionaryOnDisk::GetTargetPhraseCollectionBatch(InputPath &inputPath) const
{
  OnDiskPt::OnDiskWrapper &diskWrapper = const_cast<OnDiskPt::OnDiskWrapper&>(GetImplementation());
  const Phrase &sourcePhrase = inputPath.GetPhrase();
  const InputPath *parentPath = inputPath.GetPrevPath();

  const OnDiskPt::PhraseNode *parentNode = NULL;
  if (parentPath == NULL) {
    // Starting subphrase.
    assert(sourcePhrase.GetSize() == 1);
    parentNode = &diskWrapper.GetRootSourceNode();
  } else {
    parentNode = static_cast<const OnDiskPt::PhraseNode*>(parentPath->GetPtNode(*this));
  }

  // backoff is tested first; without a parent node there is nothing to extend
  if (!SatisfyBackoff(inputPath) || parentNode == NULL) {
    return;
  }

  Word finalWord = sourcePhrase.GetWord(sourcePhrase.GetSize() - 1);
  finalWord.OnlyTheseFactors(m_inputFactors);
  OnDiskPt::Word *finalWordOnDisk = diskWrapper.ConvertFromMoses(m_input, finalWord);

  if (finalWordOnDisk == NULL) {
    // OOV according to this phrase table. Not possible to extend
    inputPath.SetTargetPhrases(*this, NULL, NULL);
    return;
  }

  const OnDiskPt::PhraseNode *childNode = parentNode->GetChild(*finalWordOnDisk, diskWrapper);
  if (childNode == NULL) {
    inputPath.SetTargetPhrases(*this, NULL, NULL);
  } else {
    const TargetPhraseCollection *found = GetTargetPhraseCollection(childNode);
    inputPath.SetTargetPhrases(*this, found, childNode);
  }

  delete finalWordOnDisk;
}
Пример #7
0
void Manager::CreateInputPaths(const InputPath &prevPath, size_t pos)
{
  if (pos >= m_sentence.GetSize()) {
    return;
  }

  Phrase *phrase = new Phrase(prevPath.GetPhrase(), 1);
  phrase->SetLastWord(m_sentence.GetWord(pos));

  InputPath *path = new InputPath(&prevPath, phrase, pos);
  m_inputPathQueue.push_back(path);

  CreateInputPaths(*path, pos + 1);
}
Пример #8
0
void InputFeature::Evaluate(const InputType &input
                            , const InputPath &inputPath
                            , ScoreComponentCollection &scoreBreakdown) const
{
	if (m_legacy) {
		//binary phrase-table does input feature itself
		return;
	}

  const ScorePair *scores = inputPath.GetInputScore();
  if (scores) {

  }
}
Пример #9
0
void InputFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  if (m_legacy) {
    //binary phrase-table does input feature itself
    return;
  } else if (input.GetType() == WordLatticeInput) {
    const ScorePair *scores = inputPath.GetInputScore();
    if (scores) {
      scoreBreakdown.PlusEquals(this, *scores);
    }
  }
}
void TranslationOptionCollection::SetInputScore(const InputPath &inputPath, PartialTranslOptColl &oldPtoc)
{
  const ScorePair *inputScore = inputPath.GetInputScore();
  if (inputScore == NULL) {
    return;
  }

  const InputFeature *inputFeature = StaticData::Instance().GetInputFeature();

  const std::vector<TranslationOption*> &transOpts = oldPtoc.GetList();
  for (size_t i = 0; i < transOpts.size(); ++i) {
    TranslationOption &transOpt = *transOpts[i];

    ScoreComponentCollection &scores = transOpt.GetScoreBreakdown();
    scores.PlusEquals(inputFeature, *inputScore);

  }
}
void ChartRuleLookupManagerMemoryPerSentence::GetChartRuleCollection(
  const InputPath &inputPath,
  size_t lastPos,
  ChartParserCallback &outColl)
{
  const Range &range = inputPath.GetWordsRange();
  size_t startPos = range.GetStartPos();
  size_t absEndPos = range.GetEndPos();

  m_lastPos = lastPos;
  m_stackVec.clear();
  m_stackScores.clear();
  m_outColl = &outColl;
  m_unaryPos = absEndPos-1; // rules ending in this position are unary and should not be added to collection

  // create/update data structure to quickly look up all chart cells that match start position and label.
  UpdateCompressedMatrix(startPos, absEndPos, lastPos);

  const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode(GetParser().GetTranslationId());

  // all rules starting with terminal
  if (startPos == absEndPos) {
    GetTerminalExtension(&rootNode, startPos);
  }
  // all rules starting with nonterminal
  else if (absEndPos > startPos) {
    GetNonTerminalExtension(&rootNode, startPos);
  }

  // copy temporarily stored rules to out collection
  CompletedRuleCollection & rules = m_completedRules[absEndPos];
  for (vector<CompletedRule*>::const_iterator iter = rules.begin(); iter != rules.end(); ++iter) {
    outColl.Add((*iter)->GetTPC(), (*iter)->GetStackVector(), range);
  }

  rules.Clear();

}
/**
 * Builds the input paths for a word lattice.
 *
 * After initialising the base class, every word of every lattice column is
 * queued as a one-word InputPath. The queued paths are then extended column
 * by column: a path is extended with the words at endPos when its end
 * position plus its next-node offset equals endPos and the resulting span
 * does not exceed the maximum phrase length. Each extended path carries the
 * previous path's input score plus the score of the added word.
 *
 * Throws (via UTIL_THROW_IF2) when the legacy binary phrase table is in use,
 * when no input feature is configured, when a one-word entry is an epsilon
 * word, or when a path being extended has no input score.
 */
TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
  const WordLattice &input
  , size_t maxNoTransOptPerCoverage, float translationOptionThreshold)
  : TranslationOptionCollection(input, maxNoTransOptPerCoverage, translationOptionThreshold)
{
  UTIL_THROW_IF2(StaticData::Instance().GetUseLegacyPT(),
                 "Not for models using the legacy binary phrase table");

  const InputFeature *inputFeature = StaticData::Instance().GetInputFeature();
  UTIL_THROW_IF2(inputFeature == NULL,
                 "Input feature must be specified");

  size_t maxPhraseLength = StaticData::Instance().GetMaxPhraseLength();
  size_t size = input.GetSize();

  // 1-word phrases
  for (size_t startPos = 0; startPos < size; ++startPos) {

    const std::vector<size_t> &nextNodes = input.GetNextNodes(startPos);

    WordsRange range(startPos, startPos);
    const NonTerminalSet &labels = input.GetLabelSet(startPos, startPos);

    const ConfusionNet::Column &col = input.GetColumn(startPos);
    for (size_t wordInd = 0; wordInd < col.size(); ++wordInd) {
      const Word &word = col[wordInd].first;
      UTIL_THROW_IF2(word.IsEpsilon(), "Epsilon not supported");

      Phrase subphrase;
      subphrase.AddWord(word);

      const ScorePair &scores = col[wordInd].second;
      ScorePair *inputScore = new ScorePair(scores);

      InputPath *path = new InputPath(subphrase, labels, range, NULL, inputScore);

      // next-node entries are indexed like the words of the column
      path->SetNextNode(nextNodes[wordInd]);

      m_inputPathQueue.push_back(path);
    }
  }

  // iteratively extend all paths
  for (size_t endPos = 1; endPos < size; ++endPos) {
    const std::vector<size_t> &nextNodes = input.GetNextNodes(endPos);

    // loop thru every previous path; the count is taken up front so paths
    // queued during this pass are not visited until the next endPos
    size_t numPrevPaths = m_inputPathQueue.size();

    for (size_t pathInd = 0; pathInd < numPrevPaths; ++pathInd) {
      const InputPath &prevPath = *m_inputPathQueue[pathInd];

      // only paths whose next node lands exactly on endPos can be extended
      size_t prevNextNode = prevPath.GetNextNode();
      if (prevPath.GetWordsRange().GetEndPos() + prevNextNode != endPos) {
        continue;
      }

      size_t startPos = prevPath.GetWordsRange().GetStartPos();

      if (endPos - startPos + 1 > maxPhraseLength) {
        continue;
      }

      WordsRange range(startPos, endPos);
      const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);

      const Phrase &prevPhrase = prevPath.GetPhrase();
      const ScorePair *prevInputScore = prevPath.GetInputScore();
      UTIL_THROW_IF2(prevInputScore == NULL,
                     "Null previous score");

      // loop thru every word at this position
      const ConfusionNet::Column &col = input.GetColumn(endPos);

      for (size_t wordInd = 0; wordInd < col.size(); ++wordInd) {
        const Word &word = col[wordInd].first;
        Phrase subphrase(prevPhrase);
        subphrase.AddWord(word);

        // accumulated score = score of the previous path + score of this word
        const ScorePair &scores = col[wordInd].second;
        ScorePair *inputScore = new ScorePair(*prevInputScore);
        inputScore->PlusEquals(scores);

        InputPath *path = new InputPath(subphrase, labels, range, &prevPath, inputScore);

        path->SetNextNode(nextNodes[wordInd]);

        m_inputPathQueue.push_back(path);
      } // for (size_t wordInd = 0; wordInd < col.size(); ++wordInd) {

    } // for (size_t pathInd = 0; pathInd < numPrevPaths; ++pathInd) {
  }
}
void TranslationOptionCollection::CreateTranslationOptionsForRange(
  const DecodeGraph &decodeGraph
  , size_t startPos
  , size_t endPos
  , bool adhereTableLimit
  , size_t graphInd
  , InputPath &inputPath)
{
  if ((StaticData::Instance().GetXmlInputType() != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos)) {

    // partial trans opt stored in here
    PartialTranslOptColl* oldPtoc = new PartialTranslOptColl;
    size_t totalEarlyPruned = 0;

    // initial translation step
    list <const DecodeStep* >::const_iterator iterStep = decodeGraph.begin();
    const DecodeStep &decodeStep = **iterStep;

    const PhraseDictionary &phraseDictionary = *decodeStep.GetPhraseDictionaryFeature();
    const TargetPhraseCollection *targetPhrases = inputPath.GetTargetPhrases(phraseDictionary);

    static_cast<const DecodeStepTranslation&>(decodeStep).ProcessInitialTranslation
    (m_source, *oldPtoc
     , startPos, endPos, adhereTableLimit
     , inputPath, targetPhrases);

    SetInputScore(inputPath, *oldPtoc);

    // do rest of decode steps
    int indexStep = 0;

    for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {

      const DecodeStep *decodeStep = *iterStep;
      PartialTranslOptColl* newPtoc = new PartialTranslOptColl;

      // go thru each intermediate trans opt just created
      const vector<TranslationOption*>& partTransOptList = oldPtoc->GetList();
      vector<TranslationOption*>::const_iterator iterPartialTranslOpt;
      for (iterPartialTranslOpt = partTransOptList.begin() ; iterPartialTranslOpt != partTransOptList.end() ; ++iterPartialTranslOpt) {
        TranslationOption &inputPartialTranslOpt = **iterPartialTranslOpt;

        if (const DecodeStepTranslation *translateStep = dynamic_cast<const DecodeStepTranslation*>(decodeStep) ) {
          const PhraseDictionary &phraseDictionary = *translateStep->GetPhraseDictionaryFeature();
          const TargetPhraseCollection *targetPhrases = inputPath.GetTargetPhrases(phraseDictionary);
          translateStep->Process(inputPartialTranslOpt
                                 , *decodeStep
                                 , *newPtoc
                                 , this
                                 , adhereTableLimit
                                 , targetPhrases);
        } else {
          const DecodeStepGeneration *genStep = dynamic_cast<const DecodeStepGeneration*>(decodeStep);
          assert(genStep);
          genStep->Process(inputPartialTranslOpt
                           , *decodeStep
                           , *newPtoc
                           , this
                           , adhereTableLimit);
        }
      }

      // last but 1 partial trans not required anymore
      totalEarlyPruned += newPtoc->GetPrunedCount();
      delete oldPtoc;
      oldPtoc = newPtoc;

      indexStep++;
    } // for (++iterStep

    // add to fully formed translation option list
    PartialTranslOptColl &lastPartialTranslOptColl	= *oldPtoc;
    const vector<TranslationOption*>& partTransOptList = lastPartialTranslOptColl.GetList();
    vector<TranslationOption*>::const_iterator iterColl;
    for (iterColl = partTransOptList.begin() ; iterColl != partTransOptList.end() ; ++iterColl) {
      TranslationOption *transOpt = *iterColl;
      if (StaticData::Instance().GetXmlInputType() != XmlConstraint || !ViolatesXmlOptionsConstraint(startPos,endPos,transOpt)) {
        Add(transOpt);
      }
    }

    lastPartialTranslOptColl.DetachAll();
    totalEarlyPruned += oldPtoc->GetPrunedCount();
    delete oldPtoc;
    // TRACE_ERR( "Early translation options pruned: " << totalEarlyPruned << endl);
  } // if ((StaticData::Instance().GetXmlInputType() != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos))

  if (graphInd == 0 && StaticData::Instance().GetXmlInputType() != XmlPassThrough && HasXmlOptionsOverlappingRange(startPos,endPos)) {
    CreateXmlOptionsForRange(startPos, endPos);
  }
}
/**
 * Fires sparse word-translation features for every aligned source/target word
 * pair of targetPhrase.
 *
 * For each alignment point the source and target factor strings are read and,
 * depending on the configuration flags, one or more features whose names start
 * with m_description are added to scoreBreakdown:
 *  - m_simple:        "<source>~<target>"
 *  - m_domainTrigger (without m_sourceContext): the pair prefixed by a topic
 *    id, by topic ids weighted with their probabilities, or by each keyword
 *    of the document's domain vocabulary
 *  - m_sourceContext: the pair combined with the other words of the input
 *  - m_targetContext: not supported, throws runtime_error
 *
 * Alignment points are skipped for non-terminals (when the source factor type
 * is 0) and, with m_ignorePunctuation, when either word starts with a
 * character found in m_punctuationHash. Unless m_unrestricted is set, words
 * missing from the vocabularies are replaced by "OTHER".
 *
 * `input` is downcast to Sentence without a check. stackVec and
 * estimatedScores are not used.
 */
void WordTranslationFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  const Sentence& sentence = static_cast<const Sentence&>(input);
  const AlignmentInfo &alignment = targetPhrase.GetAlignTerm();

  // process aligned words
  for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {
    const Phrase& sourcePhrase = inputPath.GetPhrase();
    int sourceIndex = alignmentPoint->first;
    int targetIndex = alignmentPoint->second;
    Word ws = sourcePhrase.GetWord(sourceIndex);
    if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue;
    Word wt = targetPhrase.GetWord(targetIndex);
    // NOTE(review): this tests m_factorTypeSource for the target word as well;
    // confirm whether m_factorTypeTarget was intended.
    if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue;
    StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString();
    StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString();
    if (m_ignorePunctuation) {
      // check if source or target are punctuation
      // NOTE(review): indexes [0] without checking for an empty string.
      char firstChar = sourceWord[0];
      CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
      firstChar = targetWord[0];
      charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
    }

    // restricted mode: words outside the vocabularies collapse to "OTHER"
    if (!m_unrestricted) {
      if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end())
        sourceWord = "OTHER";
      if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end())
        targetWord = "OTHER";
    }

    if (m_simple) {
      // construct feature name
      util::StringStream featureName;
      featureName << m_description << "_";
      featureName << sourceWord;
      featureName << "~";
      featureName << targetWord;
      scoreBreakdown.SparsePlusEquals(featureName.str(), 1);
    }
    if (m_domainTrigger && !m_sourceContext) {
      const bool use_topicid = sentence.GetUseTopicId();
      const bool use_topicid_prob = sentence.GetUseTopicIdAndProb();
      if (use_topicid || use_topicid_prob) {
        if(use_topicid) {
          // use topicid as trigger
          const long topicid = sentence.GetTopicId();
          util::StringStream feature;
          feature << m_description << "_";
          if (topicid == -1)
            feature << "unk";
          else
            feature << topicid;

          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        } else {
          // use topic probabilities
          // The vector is read as alternating (topic id, probability) strings.
          // NOTE(review): element 0 is read without checking that the vector is non-empty.
          const vector<string> &topicid_prob = *(input.GetTopicIdAndProb());
          if (atol(topicid_prob[0].c_str()) == -1) {
            util::StringStream feature;
            feature << m_description << "_unk_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            scoreBreakdown.SparsePlusEquals(feature.str(), 1);
          } else {
            // i+1 < size: a trailing id without a probability is ignored
            for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
              util::StringStream feature;
              feature << m_description << "_";
              feature << topicid_prob[i];
              feature << "_";
              feature << sourceWord;
              feature << "~";
              feature << targetWord;
              scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
            }
          }
        }
      } else {
        // range over domain trigger words (keywords)
        // NOTE(review): docid indexes m_vocabDomain directly; confirm it is always a valid index.
        const long docid = input.GetDocumentId();
        for (boost::unordered_set<std::string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
          string sourceTrigger = *p;
          util::StringStream feature;
          feature << m_description << "_";
          feature << sourceTrigger;
          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_sourceContext) {
      // position of the aligned source word within the whole input
      size_t globalSourceIndex = inputPath.GetWordsRange().GetStartPos() + sourceIndex;
      if (!m_domainTrigger && globalSourceIndex == 0) {
        // add <s> trigger feature for source
        util::StringStream feature;
        feature << m_description << "_";
        feature << "<s>,";
        feature << sourceWord;
        feature << "~";
        feature << targetWord;
        scoreBreakdown.SparsePlusEquals(feature.str(), 1);
      }

      // range over source words to get context
      for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
        if (contextIndex == globalSourceIndex) continue;
        StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
        if (m_ignorePunctuation) {
          // check if trigger is punctuation
          char firstChar = sourceTrigger[0];
          CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
          if(charIterator != m_punctuationHash.end())
            continue;
        }

        // The trigger is looked up in the document's domain vocabulary when
        // domain triggers are on, else in the source vocabulary when restricted.
        const long docid = input.GetDocumentId();
        bool sourceTriggerExists = false;
        if (m_domainTrigger)
          sourceTriggerExists = FindStringPiece(m_vocabDomain[docid], sourceTrigger ) != m_vocabDomain[docid].end();
        else if (!m_unrestricted)
          sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

        if (m_domainTrigger) {
          if (sourceTriggerExists) {
            util::StringStream feature;
            feature << m_description << "_";
            feature << sourceTrigger;
            feature << "_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            scoreBreakdown.SparsePlusEquals(feature.str(), 1);
          }
        } else if (m_unrestricted || sourceTriggerExists) {
          util::StringStream feature;
          feature << m_description << "_";
          // the two source words are written in sentence order
          if (contextIndex < globalSourceIndex) {
            feature << sourceTrigger;
            feature << ",";
            feature << sourceWord;
          } else {
            feature << sourceWord;
            feature << ",";
            feature << sourceTrigger;
          }
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_targetContext) {
      throw runtime_error("Can't use target words outside current translation option in a stateless feature");
      // The disabled code below is kept as found; it refers to names
      // (cur_hypo, accumulator) that are not defined in this function.
      /*
      size_t globalTargetIndex = cur_hypo.GetCurrTargetWordsRange().GetStartPos() + targetIndex;
      if (globalTargetIndex == 0) {
      	// add <s> trigger feature for source
      	stringstream feature;
      	feature << "wt_";
      	feature << sourceWord;
      	feature << "~";
      	feature << "<s>,";
      	feature << targetWord;
      	accumulator->SparsePlusEquals(feature.str(), 1);
      }

      // range over target words (up to current position) to get context
      for(size_t contextIndex = 0; contextIndex < globalTargetIndex; contextIndex++ ) {
      	string targetTrigger = cur_hypo.GetWord(contextIndex).GetFactor(m_factorTypeTarget)->GetString();
      	if (m_ignorePunctuation) {
      		// check if trigger is punctuation
      		char firstChar = targetTrigger.at(0);
      		CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      		if(charIterator != m_punctuationHash.end())
      			continue;
      	}

      	bool targetTriggerExists = false;
      	if (!m_unrestricted)
      		targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();

      	if (m_unrestricted || targetTriggerExists) {
      		stringstream feature;
      		feature << "wt_";
      		feature << sourceWord;
      		feature << "~";
      		feature << targetTrigger;
      		feature << ",";
      		feature << targetWord;
      		accumulator->SparsePlusEquals(feature.str(), 1);
      	}
      }*/
    }
  }
}
/**
 * Extend the dotted rules stored for the start position of inputPath's span
 * and report every completed rule for that span to outColl.
 *
 * For each saved dotted rule of the span's start position:
 *  1. if the rule's next position equals the span end, try to extend it with
 *     the terminal at the span end;
 *  2. try to extend it with every (source LHS, chart non-terminal) pair
 *     available for the remaining sub-span;
 *  3. for every not-yet-processed dotted rule stored under relEndPos + 1,
 *     look up the target phrase collection for each source LHS of the span
 *     (cached in m_cache by file position) and pass non-empty collections to
 *     AddCompletedRule.
 *
 * On-disk words created through ConvertFromMoses are deleted by this function
 * after use. Nodes attached to dotted rules are pushed onto
 * m_sourcePhraseNode. lastPos is not used here.
 */
void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
  const InputPath &inputPath,
  size_t lastPos,
  ChartParserCallback &outColl)
{
  const StaticData &staticData = StaticData::Instance();
  const Word &defaultSourceNonTerm = staticData.GetInputDefaultNonTerminal();
  const WordsRange &range = inputPath.GetWordsRange();

  // relEndPos is the end position relative to the span start
  size_t relEndPos = range.GetEndPos() - range.GetStartPos();
  size_t absEndPos = range.GetEndPos();

  // MAIN LOOP. create list of nodes of target phrases
  DottedRuleStackOnDisk &expandableDottedRuleList = *m_expandableDottedRuleListVec[range.GetStartPos()];

  // sort save nodes so only do nodes with most counts
  expandableDottedRuleList.SortSavedNodes();

  const DottedRuleStackOnDisk::SavedNodeColl &savedNodeColl = expandableDottedRuleList.GetSavedNodeColl();
  //cerr << "savedNodeColl=" << savedNodeColl.size() << " ";

  const ChartCellLabel &sourceWordLabel = GetSourceAt(absEndPos);

  for (size_t ind = 0; ind < (savedNodeColl.size()) ; ++ind) {
    const SavedNodeOnDisk &savedNode = *savedNodeColl[ind];

    const DottedRuleOnDisk &prevDottedRule = savedNode.GetDottedRule();
    const OnDiskPt::PhraseNode &prevNode = prevDottedRule.GetLastNode();
    // a root rule starts at the span start, any other rule right after its own range
    size_t startPos = prevDottedRule.IsRoot() ? range.GetStartPos() : prevDottedRule.GetWordsRange().GetEndPos() + 1;

    // search for terminal symbol
    if (startPos == absEndPos) {
      OnDiskPt::Word *sourceWordBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_inputFactorsVec, sourceWordLabel.GetLabel());

      if (sourceWordBerkeleyDb != NULL) {
        const OnDiskPt::PhraseNode *node = prevNode.GetChild(*sourceWordBerkeleyDb, m_dbWrapper);
        if (node != NULL) {
          // TODO figure out why source word is needed from node, not from sentence
          // prob to do with factors or non-term
          //const Word &sourceWord = node->GetSourceWord();
          DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, sourceWordLabel, prevDottedRule);
          expandableDottedRuleList.Add(relEndPos+1, dottedRule);

          // cache for cleanup
          m_sourcePhraseNode.push_back(node);
        }

        delete sourceWordBerkeleyDb;
      }
    }

    // search for non-terminals
    size_t endPos, stackInd;
    if (startPos > absEndPos)
      continue;
    else if (startPos == range.GetStartPos() && range.GetEndPos() > range.GetStartPos()) {
      // start.
      // the non-terminal covers all but the last word of the span
      endPos = absEndPos - 1;
      stackInd = relEndPos;
    } else {
      endPos = absEndPos;
      stackInd = relEndPos + 1;
    }

    // get target nonterminals in this span from chart
    const ChartCellLabelSet &chartNonTermSet =
      GetTargetLabelSet(startPos, endPos);

    //const Word &defaultSourceNonTerm = staticData.GetInputDefaultNonTerminal()
    //                                   ,&defaultTargetNonTerm = staticData.GetOutputDefaultNonTerminal();

    // go through each SOURCE lhs
    const NonTerminalSet &sourceLHSSet = GetParser().GetInputPath(startPos, endPos).GetNonTerminalSet();

    NonTerminalSet::const_iterator iterSourceLHS;
    for (iterSourceLHS = sourceLHSSet.begin(); iterSourceLHS != sourceLHSSet.end(); ++iterSourceLHS) {
      const Word &sourceLHS = *iterSourceLHS;

      OnDiskPt::Word *sourceLHSBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_inputFactorsVec, sourceLHS);

      if (sourceLHSBerkeleyDb == NULL) {
        // deleting a null pointer is a no-op
        delete sourceLHSBerkeleyDb;
        continue; // vocab not in pt. node definitely won't be in there
      }

      const OnDiskPt::PhraseNode *sourceNode = prevNode.GetChild(*sourceLHSBerkeleyDb, m_dbWrapper);
      delete sourceLHSBerkeleyDb;

      if (sourceNode == NULL)
        continue; // didn't find source node

      // go through each TARGET lhs
      ChartCellLabelSet::const_iterator iterChartNonTerm;
      for (iterChartNonTerm = chartNonTermSet.begin(); iterChartNonTerm != chartNonTermSet.end(); ++iterChartNonTerm) {
        if (*iterChartNonTerm == NULL) {
          continue;
        }
        const ChartCellLabel &cellLabel = **iterChartNonTerm;

        bool doSearch = true;
        if (m_dictionary.m_maxSpanDefault != NOT_FOUND) {
          // for Hieu's source syntax
          // span limit differs for labelled and default source non-terminals

          bool isSourceSyntaxNonTerm = sourceLHS != defaultSourceNonTerm;
          size_t nonTermNumWordsCovered = endPos - startPos + 1;

          doSearch = isSourceSyntaxNonTerm ?
                     nonTermNumWordsCovered <=  m_dictionary.m_maxSpanLabelled :
                     nonTermNumWordsCovered <= m_dictionary.m_maxSpanDefault;

        }

        if (doSearch) {

          OnDiskPt::Word *chartNonTermBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_outputFactorsVec, cellLabel.GetLabel());

          if (chartNonTermBerkeleyDb == NULL)
            continue;

          const OnDiskPt::PhraseNode *node = sourceNode->GetChild(*chartNonTermBerkeleyDb, m_dbWrapper);
          delete chartNonTermBerkeleyDb;

          if (node == NULL)
            continue;

          // found matching entry
          //const Word &sourceWord = node->GetSourceWord();
          DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, cellLabel, prevDottedRule);
          expandableDottedRuleList.Add(stackInd, dottedRule);

          m_sourcePhraseNode.push_back(node);
        }
      } // for (iterChartNonTerm

      // the intermediate source-LHS node is released once all of its children were tried
      delete sourceNode;

    } // for (iterLabelListf

    // return list of target phrases
    DottedRuleCollOnDisk &nodes = expandableDottedRuleList.Get(relEndPos + 1);

    // source LHS
    DottedRuleCollOnDisk::const_iterator iterDottedRuleColl;
    for (iterDottedRuleColl = nodes.begin(); iterDottedRuleColl != nodes.end(); ++iterDottedRuleColl) {
      // node of last source word
      // NOTE: prevDottedRule and prevNode below shadow the outer variables of the same name
      const DottedRuleOnDisk &prevDottedRule = **iterDottedRuleColl;
      // each dotted rule is processed once only
      if (prevDottedRule.Done())
        continue;
      prevDottedRule.Done(true);

      const OnDiskPt::PhraseNode &prevNode = prevDottedRule.GetLastNode();

      //get node for each source LHS
      const NonTerminalSet &lhsSet = GetParser().GetInputPath(range.GetStartPos(), range.GetEndPos()).GetNonTerminalSet();
      NonTerminalSet::const_iterator iterLabelSet;
      for (iterLabelSet = lhsSet.begin(); iterLabelSet != lhsSet.end(); ++iterLabelSet) {
        const Word &sourceLHS = *iterLabelSet;

        OnDiskPt::Word *sourceLHSBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_inputFactorsVec, sourceLHS);
        if (sourceLHSBerkeleyDb == NULL)
          continue;

        const TargetPhraseCollection *targetPhraseCollection = NULL;
        const OnDiskPt::PhraseNode *node = prevNode.GetChild(*sourceLHSBerkeleyDb, m_dbWrapper);
        if (node) {
          // the node's value is used as the key of the collection cache
          uint64_t tpCollFilePos = node->GetValue();
          std::map<uint64_t, const TargetPhraseCollection*>::const_iterator iterCache = m_cache.find(tpCollFilePos);
          if (iterCache == m_cache.end()) {

            // not cached yet: load the on-disk collection and convert it
            const OnDiskPt::TargetPhraseCollection *tpcollBerkeleyDb = node->GetTargetPhraseCollection(m_dictionary.GetTableLimit(), m_dbWrapper);

            std::vector<float> weightT = staticData.GetWeights(&m_dictionary);
            targetPhraseCollection
            = tpcollBerkeleyDb->ConvertToMoses(m_inputFactorsVec
                                               ,m_outputFactorsVec
                                               ,m_dictionary
                                               ,weightT
                                               ,m_dbWrapper.GetVocab()
                                               ,true);

            delete tpcollBerkeleyDb;
            m_cache[tpCollFilePos] = targetPhraseCollection;
          } else {
            // just get out of cache
            targetPhraseCollection = iterCache->second;
          }

          UTIL_THROW_IF2(targetPhraseCollection == NULL, "Error");
          if (!targetPhraseCollection->IsEmpty()) {
            AddCompletedRule(prevDottedRule, *targetPhraseCollection,
                             range, outColl);
          }

        } // if (node)

        // node may be NULL here; deleting a null pointer is a no-op
        delete node;
        delete sourceLHSBerkeleyDb;
      }
    }
  } // for (size_t ind = 0; ind < savedNodeColl.size(); ++ind)

  //cerr << numDerivations << " ";
}
Пример #16
0
void PhrasePairFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedFutureScore) const
{
  const Phrase& source = inputPath.GetPhrase();
  if (m_simple) {
    ostringstream namestr;
    namestr << "pp_";
    namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      namestr << ",";
      namestr << sourceFactor->GetString();
    }
    namestr << "~";
    namestr << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
      const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
      namestr << ",";
      namestr << targetFactor->GetString();
    }

    scoreBreakdown.SparsePlusEquals(namestr.str(),1);
  }
  if (m_domainTrigger) {
    const Sentence& isnt = static_cast<const Sentence&>(input);
    const bool use_topicid = isnt.GetUseTopicId();
    const bool use_topicid_prob = isnt.GetUseTopicIdAndProb();

    // compute pair
    ostringstream pair;
    pair << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      pair << ",";
      pair << sourceFactor->GetString();
    }
    pair << "~";
    pair << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
      const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
      pair << ",";
      pair << targetFactor->GetString();
    }

    if (use_topicid || use_topicid_prob) {
      if(use_topicid) {
        // use topicid as trigger
        const long topicid = isnt.GetTopicId();
        stringstream feature;
        feature << "pp_";
        if (topicid == -1)
          feature << "unk";
        else
          feature << topicid;

        feature << "_";
        feature << pair.str();
        scoreBreakdown.SparsePlusEquals(feature.str(), 1);
      } else {
        // use topic probabilities
        const vector<string> &topicid_prob = *(isnt.GetTopicIdAndProb());
        if (atol(topicid_prob[0].c_str()) == -1) {
          stringstream feature;
          feature << "pp_unk_";
          feature << pair.str();
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        } else {
          for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
            stringstream feature;
            feature << "pp_";
            feature << topicid_prob[i];
            feature << "_";
            feature << pair.str();
            scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
          }
        }
      }
    } else {
      // range over domain trigger words
      const long docid = isnt.GetDocumentId();
      for (set<string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
        string sourceTrigger = *p;
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "_";
        namestr << pair.str();
        scoreBreakdown.SparsePlusEquals(namestr.str(),1);
      }
    }
  }
  if (m_sourceContext) {
    const Sentence& isnt = static_cast<const Sentence&>(input);

    // range over source words to get context
    for(size_t contextIndex = 0; contextIndex < isnt.GetSize(); contextIndex++ ) {
      StringPiece sourceTrigger = isnt.GetWord(contextIndex).GetFactor(m_sourceFactorId)->GetString();
      if (m_ignorePunctuation) {
        // check if trigger is punctuation
        char firstChar = sourceTrigger[0];
        CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
        if(charIterator != m_punctuationHash.end())
          continue;
      }

      bool sourceTriggerExists = false;
      if (!m_unrestricted)
        sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

      if (m_unrestricted || sourceTriggerExists) {
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "~";
        namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
        for (size_t i = 1; i < source.GetSize(); ++i) {
          const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
          namestr << ",";
          namestr << sourceFactor->GetString();
        }
        namestr << "~";
        namestr << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
        for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
          const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
          namestr << ",";
          namestr << targetFactor->GetString();
        }

        scoreBreakdown.SparsePlusEquals(namestr.str(),1);
      }
    }
  }
}