void UnknownWordPenalty::ProcessXML( const Manager &mgr, MemPool &pool, const Sentence &sentence, InputPaths &inputPaths) const { const Vector<const InputType::XMLOption*> &xmlOptions = sentence.GetXMLOptions(); BOOST_FOREACH(const InputType::XMLOption *xmlOption, xmlOptions) { TargetPhraseImpl *target = TargetPhraseImpl::CreateFromString(pool, *this, mgr.system, xmlOption->GetTranslation()); if (xmlOption->prob) { Scores &scores = target->GetScores(); scores.PlusEquals(mgr.system, *this, Moses2::TransformScore(xmlOption->prob)); } InputPath *path = inputPaths.GetMatrix().GetValue(xmlOption->startPos, xmlOption->phraseSize - 1); const SubPhrase<Moses2::Word> &source = path->subPhrase; mgr.system.featureFunctions.EvaluateInIsolation(pool, mgr.system, source, *target); TargetPhrases *tps = new (pool.Allocate<TargetPhrases>()) TargetPhrases(pool, 1); tps->AddTargetPhrase(*target); mgr.system.featureFunctions.EvaluateAfterTablePruning(pool, *tps, source); path->AddTargetPhrases(*this, tps); }
bool PhraseDictionaryTransliteration::SatisfyBackoff(const InputPath &inputPath) const { const Phrase &sourcePhrase = inputPath.GetPhrase(); assert(m_container); const DecodeGraph *decodeGraph = m_container->GetContainer(); size_t backoff = decodeGraph->GetBackoff(); if (backoff == 0) { // ie. don't backoff. Collect ALL translations return true; } if (sourcePhrase.GetSize() > backoff) { // source phrase too big return false; } // lookup translation only if no other translations InputPath::TargetPhrases::const_iterator iter; for (iter = inputPath.GetTargetPhrases().begin(); iter != inputPath.GetTargetPhrases().end(); ++iter) { const std::pair<const TargetPhraseCollection*, const void*> &temp = iter->second; const TargetPhraseCollection *tpCollPrev = temp.first; if (tpCollPrev && tpCollPrev->GetSize()) { // already have translation from another pt. Don't create translations return false; } } return true; }
/**
 * Produce (or fetch from cache) the transliteration candidates for the
 * source phrase of this input path, and attach them to the path.
 *
 * Cache hit: reuse the previously built TargetPhraseCollection.
 * Cache miss: write the phrase to a temp file, shell out to the
 * transliteration Perl script, read its output back into target phrases,
 * cache the collection (clock-stamped), and delete the temp files.
 *
 * NOTE(review): tmpnam() has an inherent create-race (name is predictable
 * between generation and open) — mkstemp/mkdtemp would be safer; verify
 * against project portability constraints before changing.
 * NOTE(review): the command line is built by string concatenation from
 * config paths; if any of those can contain shell metacharacters this is
 * a shell-injection risk — presumably they are trusted config values.
 */
void PhraseDictionaryTransliteration::GetTargetPhraseCollection(InputPath &inputPath) const
{
  const Phrase &sourcePhrase = inputPath.GetPhrase();
  size_t hash = hash_value(sourcePhrase);

  CacheColl &cache = GetCache();

  std::map<size_t, std::pair<const TargetPhraseCollection*, clock_t> >::iterator iter;
  iter = cache.find(hash);

  if (iter != cache.end()) {
    // already in cache
    const TargetPhraseCollection *tpColl = iter->second.first;
    inputPath.SetTargetPhrases(*this, tpColl, NULL);
  } else {
    // TRANSLITERATE
    // tmpnam returns a static buffer: copy each name before the next call.
    char *ptr = tmpnam(NULL);
    string inFile(ptr);
    ptr = tmpnam(NULL);
    string outDir(ptr);

    // Write the OOV phrase for the external script to consume.
    ofstream inStream(inFile.c_str());
    inStream << sourcePhrase.ToString() << endl;
    inStream.close();

    string cmd = m_scriptDir + "/Transliteration/prepare-transliteration-phrase-table.pl" +
                 " --transliteration-model-dir " + m_filePath +
                 " --moses-src-dir " + m_mosesDir +
                 " --external-bin-dir " + m_externalDir +
                 " --input-extension " + m_inputLang +
                 " --output-extension " + m_outputLang +
                 " --oov-file " + inFile +
                 " --out-dir " + outDir;

    int ret = system(cmd.c_str());
    UTIL_THROW_IF2(ret != 0, "Transliteration script error");

    TargetPhraseCollection *tpColl = new TargetPhraseCollection();
    vector<TargetPhrase*> targetPhrases = CreateTargetPhrases(sourcePhrase, outDir);
    vector<TargetPhrase*>::const_iterator iter;
    for (iter = targetPhrases.begin(); iter != targetPhrases.end(); ++iter) {
      TargetPhrase *tp = *iter;
      tpColl->Add(tp);
    }

    // Clock-stamp the cached collection (timestamp used by cache cleanup).
    std::pair<const TargetPhraseCollection*, clock_t> value(tpColl, clock());
    cache[hash] = value;

    inputPath.SetTargetPhrases(*this, tpColl, NULL);

    // clean up temporary files
    remove(inFile.c_str());
    cmd = "rm -rf " + outDir;
    system(cmd.c_str());
  }
}
// NOTE(review): loop fragment — the header of the enclosing function is
// outside this chunk (presumably a phrase-table batch lookup over all
// input paths; verify against the full file).
BOOST_FOREACH(InputPathBase *pathBase, inputPaths) {
  InputPath *path = static_cast<InputPath*>(pathBase);

  // Only consult this table if the backoff policy allows it for this path.
  if (SatisfyBackoff(mgr, *path)) {
    // NOTE(review): 'phrase' is bound here but never used below —
    // looks like a leftover; confirm before removing.
    const SubPhrase<Moses2::Word> &phrase = path->subPhrase;

    TargetPhrases *tps = Lookup(mgr, mgr.GetPool(), *path);
    path->AddTargetPhrases(*this, tps);
  }
}
// assumes that source-side syntax labels are stored in the target non-terminal field of the rules void SourceGHKMTreeInputMatchFeature::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedScores) const { const Range& range = inputPath.GetWordsRange(); size_t startPos = range.GetStartPos(); size_t endPos = range.GetEndPos(); const TreeInput& treeInput = static_cast<const TreeInput&>(input); const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(startPos,endPos); const Word& lhsLabel = targetPhrase.GetTargetLHS(); const StaticData& staticData = StaticData::Instance(); const Word& outputDefaultNonTerminal = staticData.GetOutputDefaultNonTerminal(); std::vector<float> newScores(m_numScoreComponents,0.0); // m_numScoreComponents == 2 // first fires for matches, second for mismatches if ( (treeInputLabels.find(lhsLabel) != treeInputLabels.end()) && (lhsLabel != outputDefaultNonTerminal) ) { // match newScores[0] = 1.0; } else { // mismatch newScores[1] = 1.0; } scoreBreakdown.PlusEquals(this, newScores); }
/**
 * Extend the on-disk phrase-table lookup by one word for this input path.
 *
 * Uses the trie node cached on the previous (one-word-shorter) path as the
 * starting point — or the root node for a 1-word path — then walks one
 * child edge for the path's last word and attaches the resulting target
 * phrases (and the reached node, for future extension) to the path.
 */
void PhraseDictionaryOnDisk::GetTargetPhraseCollectionBatch(InputPath &inputPath) const
{
  // const_cast: the on-disk wrapper's lookup methods are non-const
  // (file access) even though this dictionary method is logically const.
  OnDiskPt::OnDiskWrapper &wrapper = const_cast<OnDiskPt::OnDiskWrapper&>(GetImplementation());
  const Phrase &phrase = inputPath.GetPhrase();
  const InputPath *prevInputPath = inputPath.GetPrevPath();

  const OnDiskPt::PhraseNode *prevPtNode = NULL;

  if (prevInputPath) {
    // Resume from the trie node the shorter path reached in this table.
    prevPtNode = static_cast<const OnDiskPt::PhraseNode*>(prevInputPath->GetPtNode(*this));
  } else {
    // Starting subphrase.
    assert(phrase.GetSize() == 1);
    prevPtNode = &wrapper.GetRootSourceNode();
  }

  // backoff
  if (!SatisfyBackoff(inputPath)) {
    return;
  }

  // prevPtNode may be NULL if the shorter path was already a dead end.
  if (prevPtNode) {
    Word lastWord = phrase.GetWord(phrase.GetSize() - 1);
    lastWord.OnlyTheseFactors(m_inputFactors);

    // ConvertFromMoses returns a heap-allocated word (or NULL for OOV);
    // it is deleted below once the child lookup is done.
    OnDiskPt::Word *lastWordOnDisk = wrapper.ConvertFromMoses(m_input, lastWord);
    if (lastWordOnDisk == NULL) {
      // OOV according to this phrase table. Not possible to extend
      inputPath.SetTargetPhrases(*this, NULL, NULL);
    } else {
      const OnDiskPt::PhraseNode *ptNode = prevPtNode->GetChild(*lastWordOnDisk, wrapper);
      if (ptNode) {
        const TargetPhraseCollection *targetPhrases = GetTargetPhraseCollection(ptNode);
        // Store the node too, so longer paths can continue the walk.
        inputPath.SetTargetPhrases(*this, targetPhrases, ptNode);
      } else {
        inputPath.SetTargetPhrases(*this, NULL, NULL);
      }

      delete lastWordOnDisk;
    }
  }
}
void Manager::CreateInputPaths(const InputPath &prevPath, size_t pos) { if (pos >= m_sentence.GetSize()) { return; } Phrase *phrase = new Phrase(prevPath.GetPhrase(), 1); phrase->SetLastWord(m_sentence.GetWord(pos)); InputPath *path = new InputPath(&prevPath, phrase, pos); m_inputPathQueue.push_back(path); CreateInputPaths(*path, pos + 1); }
void InputFeature::Evaluate(const InputType &input , const InputPath &inputPath , ScoreComponentCollection &scoreBreakdown) const { if (m_legacy) { //binary phrase-table does input feature itself return; } const ScorePair *scores = inputPath.GetInputScore(); if (scores) { } }
void InputFeature::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedScores) const { if (m_legacy) { //binary phrase-table does input feature itself return; } else if (input.GetType() == WordLatticeInput) { const ScorePair *scores = inputPath.GetInputScore(); if (scores) { scoreBreakdown.PlusEquals(this, *scores); } } }
void TranslationOptionCollection::SetInputScore(const InputPath &inputPath, PartialTranslOptColl &oldPtoc) { const ScorePair *inputScore = inputPath.GetInputScore(); if (inputScore == NULL) { return; } const InputFeature *inputFeature = StaticData::Instance().GetInputFeature(); const std::vector<TranslationOption*> &transOpts = oldPtoc.GetList(); for (size_t i = 0; i < transOpts.size(); ++i) { TranslationOption &transOpt = *transOpts[i]; ScoreComponentCollection &scores = transOpt.GetScoreBreakdown(); scores.PlusEquals(inputFeature, *inputScore); } }
/**
 * Collect all rules applicable to the span of inputPath and hand the
 * completed ones to outColl.
 *
 * Resets per-span lookup state (stack vectors, output collector), refreshes
 * the compressed chart-cell matrix, then extends rules from the root of the
 * per-sentence rule trie — terminal extension for 1-word spans, nonterminal
 * extension for longer spans — and finally flushes the temporarily stored
 * completed rules for this end position into the callback.
 */
void ChartRuleLookupManagerMemoryPerSentence::GetChartRuleCollection(
  const InputPath &inputPath,
  size_t lastPos,
  ChartParserCallback &outColl)
{
  const Range &range = inputPath.GetWordsRange();
  size_t startPos = range.GetStartPos();
  size_t absEndPos = range.GetEndPos();

  // Per-call member state used by the extension helpers.
  m_lastPos = lastPos;
  m_stackVec.clear();
  m_stackScores.clear();
  m_outColl = &outColl;
  m_unaryPos = absEndPos-1; // rules ending in this position are unary and should not be added to collection

  // create/update data structure to quickly look up all chart cells that match start position and label.
  UpdateCompressedMatrix(startPos, absEndPos, lastPos);

  const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode(GetParser().GetTranslationId());

  // all rules starting with terminal
  if (startPos == absEndPos) {
    GetTerminalExtension(&rootNode, startPos);
  }
  // all rules starting with nonterminal
  else if (absEndPos > startPos) {
    GetNonTerminalExtension(&rootNode, startPos);
  }

  // copy temporarily stored rules to out collection
  CompletedRuleCollection & rules = m_completedRules[absEndPos];

  for (vector<CompletedRule*>::const_iterator iter = rules.begin(); iter != rules.end(); ++iter) {
    outColl.Add((*iter)->GetTPC(), (*iter)->GetStackVector(), range);
  }

  rules.Clear();
}
/**
 * Constructor: initialize the base class, then eagerly build the full set
 * of input paths through the word lattice.
 *
 * Phase 1 seeds one path per (position, lattice-column word).
 * Phase 2 iteratively extends every existing path whose next-node offset
 * lands exactly on endPos, accumulating input scores along the way, and
 * capping path length at the configured maximum phrase length.
 *
 * Fix: corrected the typo "legqacy" -> "legacy" in the thrown error
 * message. No other code change.
 *
 * Ownership note: the raw Phrase/ScorePair/InputPath allocations are handed
 * to m_inputPathQueue, which is presumed to own them.
 */
TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
  const WordLattice &input
  , size_t maxNoTransOptPerCoverage, float translationOptionThreshold)
  : TranslationOptionCollection(input, maxNoTransOptPerCoverage, translationOptionThreshold)
{
  UTIL_THROW_IF2(StaticData::Instance().GetUseLegacyPT(),
                 "Not for models using the legacy binary phrase table");

  const InputFeature *inputFeature = StaticData::Instance().GetInputFeature();
  UTIL_THROW_IF2(inputFeature == NULL, "Input feature must be specified");

  size_t maxPhraseLength = StaticData::Instance().GetMaxPhraseLength();
  size_t size = input.GetSize();

  // 1-word phrases
  for (size_t startPos = 0; startPos < size; ++startPos) {
    const std::vector<size_t> &nextNodes = input.GetNextNodes(startPos);

    WordsRange range(startPos, startPos);
    const NonTerminalSet &labels = input.GetLabelSet(startPos, startPos);

    const ConfusionNet::Column &col = input.GetColumn(startPos);
    for (size_t i = 0; i < col.size(); ++i) {
      const Word &word = col[i].first;
      UTIL_THROW_IF2(word.IsEpsilon(), "Epsilon not supported");

      Phrase subphrase;
      subphrase.AddWord(word);

      const ScorePair &scores = col[i].second;
      ScorePair *inputScore = new ScorePair(scores);

      InputPath *path = new InputPath(subphrase, labels, range, NULL, inputScore);

      size_t nextNode = nextNodes[i];
      path->SetNextNode(nextNode);

      m_inputPathQueue.push_back(path);
    }
  }

  // iteratively extend all paths
  for (size_t endPos = 1; endPos < size; ++endPos) {
    const std::vector<size_t> &nextNodes = input.GetNextNodes(endPos);

    // loop thru every previous paths
    size_t numPrevPaths = m_inputPathQueue.size();

    for (size_t i = 0; i < numPrevPaths; ++i) {
      //for (size_t pathInd = 0; pathInd < prevPaths.size(); ++pathInd) {
      const InputPath &prevPath = *m_inputPathQueue[i];

      size_t nextNode = prevPath.GetNextNode();
      // Only extend paths whose outgoing lattice edge ends at endPos.
      if (prevPath.GetWordsRange().GetEndPos() + nextNode != endPos) {
        continue;
      }

      size_t startPos = prevPath.GetWordsRange().GetStartPos();

      if (endPos - startPos + 1 > maxPhraseLength) {
        continue;
      }

      WordsRange range(startPos, endPos);
      const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);

      const Phrase &prevPhrase = prevPath.GetPhrase();
      const ScorePair *prevInputScore = prevPath.GetInputScore();
      UTIL_THROW_IF2(prevInputScore == NULL, "Null previous score");

      // loop thru every word at this position
      const ConfusionNet::Column &col = input.GetColumn(endPos);

      for (size_t i = 0; i < col.size(); ++i) {
        const Word &word = col[i].first;
        Phrase subphrase(prevPhrase);
        subphrase.AddWord(word);

        // Extended path's score = previous accumulated score + this edge.
        const ScorePair &scores = col[i].second;
        ScorePair *inputScore = new ScorePair(*prevInputScore);
        inputScore->PlusEquals(scores);

        InputPath *path = new InputPath(subphrase, labels, range, &prevPath, inputScore);

        size_t nextNode = nextNodes[i];
        path->SetNextNode(nextNode);

        m_inputPathQueue.push_back(path);
      } // for (size_t i = 0; i < col.size(); ++i) {
    } // for (size_t i = 0; i < numPrevPaths; ++i) {
  }
}
/**
 * Create all translation options for the span [startPos, endPos] by running
 * the full decode graph (initial translation step, then the remaining
 * translation/generation steps) and adding the surviving options to this
 * collection. XML input handling: the normal pipeline is skipped entirely
 * for XmlExclusive spans covered by XML options, and XML-derived options
 * are added for graph 0 unless in pass-through mode.
 *
 * Ownership: PartialTranslOptColl instances are created/deleted per step;
 * DetachAll releases the final options to this collection before delete.
 */
void TranslationOptionCollection::CreateTranslationOptionsForRange(
  const DecodeGraph &decodeGraph
  , size_t startPos
  , size_t endPos
  , bool adhereTableLimit
  , size_t graphInd
  , InputPath &inputPath)
{
  if ((StaticData::Instance().GetXmlInputType() != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos)) {
    // partial trans opt stored in here
    PartialTranslOptColl* oldPtoc = new PartialTranslOptColl;
    size_t totalEarlyPruned = 0;

    // initial translation step
    list <const DecodeStep* >::const_iterator iterStep = decodeGraph.begin();
    const DecodeStep &decodeStep = **iterStep;

    const PhraseDictionary &phraseDictionary = *decodeStep.GetPhraseDictionaryFeature();
    const TargetPhraseCollection *targetPhrases = inputPath.GetTargetPhrases(phraseDictionary);

    static_cast<const DecodeStepTranslation&>(decodeStep).ProcessInitialTranslation
    (m_source, *oldPtoc
     , startPos, endPos, adhereTableLimit
     , inputPath, targetPhrases);

    // Fold lattice/confusion-net input scores into the initial options.
    SetInputScore(inputPath, *oldPtoc);

    // do rest of decode steps
    // NOTE(review): indexStep is incremented but never read in this body.
    int indexStep = 0;

    for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {
      // NOTE(review): this pointer shadows the outer 'decodeStep' reference.
      const DecodeStep *decodeStep = *iterStep;
      PartialTranslOptColl* newPtoc = new PartialTranslOptColl;

      // go thru each intermediate trans opt just created
      const vector<TranslationOption*>& partTransOptList = oldPtoc->GetList();
      vector<TranslationOption*>::const_iterator iterPartialTranslOpt;
      for (iterPartialTranslOpt = partTransOptList.begin() ; iterPartialTranslOpt != partTransOptList.end() ; ++iterPartialTranslOpt) {
        TranslationOption &inputPartialTranslOpt = **iterPartialTranslOpt;

        // Dispatch on the dynamic step type: translation vs generation.
        if (const DecodeStepTranslation *translateStep = dynamic_cast<const DecodeStepTranslation*>(decodeStep) ) {
          const PhraseDictionary &phraseDictionary = *translateStep->GetPhraseDictionaryFeature();
          const TargetPhraseCollection *targetPhrases = inputPath.GetTargetPhrases(phraseDictionary);
          translateStep->Process(inputPartialTranslOpt
                                 , *decodeStep
                                 , *newPtoc
                                 , this
                                 , adhereTableLimit
                                 , targetPhrases);
        } else {
          const DecodeStepGeneration *genStep = dynamic_cast<const DecodeStepGeneration*>(decodeStep);
          assert(genStep);
          genStep->Process(inputPartialTranslOpt
                           , *decodeStep
                           , *newPtoc
                           , this
                           , adhereTableLimit);
        }
      }

      // last but 1 partial trans not required anymore
      totalEarlyPruned += newPtoc->GetPrunedCount();
      delete oldPtoc;
      oldPtoc = newPtoc;

      indexStep++;
    } // for (++iterStep

    // add to fully formed translation option list
    PartialTranslOptColl &lastPartialTranslOptColl = *oldPtoc;
    const vector<TranslationOption*>& partTransOptList = lastPartialTranslOptColl.GetList();
    vector<TranslationOption*>::const_iterator iterColl;
    for (iterColl = partTransOptList.begin() ; iterColl != partTransOptList.end() ; ++iterColl) {
      TranslationOption *transOpt = *iterColl;
      // In XmlConstraint mode, drop options that violate XML constraints.
      if (StaticData::Instance().GetXmlInputType() != XmlConstraint || !ViolatesXmlOptionsConstraint(startPos,endPos,transOpt)) {
        Add(transOpt);
      }
    }

    // Detach before delete so the added options survive the collection.
    lastPartialTranslOptColl.DetachAll();
    totalEarlyPruned += oldPtoc->GetPrunedCount();
    delete oldPtoc;
    // TRACE_ERR( "Early translation options pruned: " << totalEarlyPruned << endl);
  } // if ((StaticData::Instance().GetXmlInputType() != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos))

  if (graphInd == 0 && StaticData::Instance().GetXmlInputType() != XmlPassThrough && HasXmlOptionsOverlappingRange(startPos,endPos)) {
    CreateXmlOptionsForRange(startPos, endPos);
  }
}
/**
 * Fire sparse word-translation features for every aligned source~target
 * word pair of this phrase pair.
 *
 * Feature variants, each gated by a config flag:
 *  - m_simple:        plain "desc_src~tgt" indicator.
 *  - m_domainTrigger: features conditioned on topic id / topic
 *                     probabilities / per-document keyword triggers.
 *  - m_sourceContext: features pairing each (src,tgt) with every other
 *                     source word in the sentence (order-sensitive name).
 *  - m_targetContext: unsupported in a stateless feature — throws.
 * Out-of-vocabulary words collapse to "OTHER" unless m_unrestricted.
 */
void WordTranslationFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  const Sentence& sentence = static_cast<const Sentence&>(input);
  const AlignmentInfo &alignment = targetPhrase.GetAlignTerm();

  // process aligned words
  for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {
    const Phrase& sourcePhrase = inputPath.GetPhrase();
    int sourceIndex = alignmentPoint->first;
    int targetIndex = alignmentPoint->second;
    Word ws = sourcePhrase.GetWord(sourceIndex);
    // Skip non-terminals when operating on the surface factor.
    if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue;
    Word wt = targetPhrase.GetWord(targetIndex);
    // NOTE(review): this checks m_factorTypeSource against a TARGET word —
    // possibly intentional, but looks like it should be m_factorTypeTarget;
    // verify against upstream history before changing.
    if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue;
    StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString();
    StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString();
    if (m_ignorePunctuation) {
      // check if source or target are punctuation
      char firstChar = sourceWord[0];
      CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end()) continue;
      firstChar = targetWord[0];
      charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end()) continue;
    }
    // Collapse OOV words to the shared "OTHER" bucket.
    if (!m_unrestricted) {
      if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end()) sourceWord = "OTHER";
      if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end()) targetWord = "OTHER";
    }
    if (m_simple) {
      // construct feature name
      util::StringStream featureName;
      featureName << m_description << "_";
      featureName << sourceWord;
      featureName << "~";
      featureName << targetWord;
      scoreBreakdown.SparsePlusEquals(featureName.str(), 1);
    }
    if (m_domainTrigger && !m_sourceContext) {
      const bool use_topicid = sentence.GetUseTopicId();
      const bool use_topicid_prob = sentence.GetUseTopicIdAndProb();
      if (use_topicid || use_topicid_prob) {
        if(use_topicid) {
          // use topicid as trigger
          const long topicid = sentence.GetTopicId();
          util::StringStream feature;
          feature << m_description << "_";
          if (topicid == -1) feature << "unk";
          else feature << topicid;
          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        } else {
          // use topic probabilities
          const vector<string> &topicid_prob = *(input.GetTopicIdAndProb());
          if (atol(topicid_prob[0].c_str()) == -1) {
            util::StringStream feature;
            feature << m_description << "_unk_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            scoreBreakdown.SparsePlusEquals(feature.str(), 1);
          } else {
            // topicid_prob alternates id, probability — step by 2.
            for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
              util::StringStream feature;
              feature << m_description << "_";
              feature << topicid_prob[i];
              feature << "_";
              feature << sourceWord;
              feature << "~";
              feature << targetWord;
              scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
            }
          }
        }
      } else {
        // range over domain trigger words (keywords)
        const long docid = input.GetDocumentId();
        for (boost::unordered_set<std::string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
          string sourceTrigger = *p;
          util::StringStream feature;
          feature << m_description << "_";
          feature << sourceTrigger;
          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_sourceContext) {
      size_t globalSourceIndex = inputPath.GetWordsRange().GetStartPos() + sourceIndex;
      if (!m_domainTrigger && globalSourceIndex == 0) {
        // add <s> trigger feature for source
        util::StringStream feature;
        feature << m_description << "_";
        feature << "<s>,";
        feature << sourceWord;
        feature << "~";
        feature << targetWord;
        scoreBreakdown.SparsePlusEquals(feature.str(), 1);
      }
      // range over source words to get context
      for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
        if (contextIndex == globalSourceIndex) continue;
        StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
        if (m_ignorePunctuation) {
          // check if trigger is punctuation
          char firstChar = sourceTrigger[0];
          CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
          if(charIterator != m_punctuationHash.end()) continue;
        }
        const long docid = input.GetDocumentId();
        bool sourceTriggerExists = false;
        if (m_domainTrigger) sourceTriggerExists = FindStringPiece(m_vocabDomain[docid], sourceTrigger ) != m_vocabDomain[docid].end();
        else if (!m_unrestricted) sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
        if (m_domainTrigger) {
          if (sourceTriggerExists) {
            util::StringStream feature;
            feature << m_description << "_";
            feature << sourceTrigger;
            feature << "_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            scoreBreakdown.SparsePlusEquals(feature.str(), 1);
          }
        } else if (m_unrestricted || sourceTriggerExists) {
          util::StringStream feature;
          feature << m_description << "_";
          // Context word order in the feature name encodes whether the
          // trigger appears before or after the aligned source word.
          if (contextIndex < globalSourceIndex) {
            feature << sourceTrigger;
            feature << ",";
            feature << sourceWord;
          } else {
            feature << sourceWord;
            feature << ",";
            feature << sourceTrigger;
          }
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_targetContext) {
      throw runtime_error("Can't use target words outside current translation option in a stateless feature");
      /*
      size_t globalTargetIndex = cur_hypo.GetCurrTargetWordsRange().GetStartPos() + targetIndex;
      if (globalTargetIndex == 0) {
        // add <s> trigger feature for source
        stringstream feature;
        feature << "wt_";
        feature << sourceWord;
        feature << "~";
        feature << "<s>,";
        feature << targetWord;
        accumulator->SparsePlusEquals(feature.str(), 1);
      }

      // range over target words (up to current position) to get context
      for(size_t contextIndex = 0; contextIndex < globalTargetIndex; contextIndex++ ) {
        string targetTrigger = cur_hypo.GetWord(contextIndex).GetFactor(m_factorTypeTarget)->GetString();
        if (m_ignorePunctuation) {
          // check if trigger is punctuation
          char firstChar = targetTrigger.at(0);
          CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
          if(charIterator != m_punctuationHash.end()) continue;
        }
        bool targetTriggerExists = false;
        if (!m_unrestricted) targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();
        if (m_unrestricted || targetTriggerExists) {
          stringstream feature;
          feature << "wt_";
          feature << sourceWord;
          feature << "~";
          feature << targetTrigger;
          feature << ",";
          feature << targetWord;
          accumulator->SparsePlusEquals(feature.str(), 1);
        }
      }
      */
    }
  }
}
/**
 * Look up chart rules for the span of inputPath from the on-disk
 * (Berkeley-DB style) rule table and feed completed rules to outColl.
 *
 * For every saved dotted rule starting in this cell it: (1) tries to extend
 * by the terminal at the span end, (2) tries to extend by each compatible
 * source-LHS / target-label non-terminal pair (span-limited for labelled
 * non-terminals), and (3) converts the reached nodes' target phrase
 * collections to Moses form (with a file-position cache) and emits them.
 *
 * Memory notes: OnDiskPt words returned by ConvertFromMoses are heap
 * allocated and deleted after each child lookup; trie nodes added to
 * m_sourcePhraseNode are cached for later cleanup.
 */
void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
  const InputPath &inputPath,
  size_t lastPos,
  ChartParserCallback &outColl)
{
  const StaticData &staticData = StaticData::Instance();
  const Word &defaultSourceNonTerm = staticData.GetInputDefaultNonTerminal();

  const WordsRange &range = inputPath.GetWordsRange();
  size_t relEndPos = range.GetEndPos() - range.GetStartPos();
  size_t absEndPos = range.GetEndPos();

  // MAIN LOOP. create list of nodes of target phrases
  DottedRuleStackOnDisk &expandableDottedRuleList = *m_expandableDottedRuleListVec[range.GetStartPos()];

  // sort save nodes so only do nodes with most counts
  expandableDottedRuleList.SortSavedNodes();

  const DottedRuleStackOnDisk::SavedNodeColl &savedNodeColl = expandableDottedRuleList.GetSavedNodeColl();
  //cerr << "savedNodeColl=" << savedNodeColl.size() << " ";

  const ChartCellLabel &sourceWordLabel = GetSourceAt(absEndPos);

  for (size_t ind = 0; ind < (savedNodeColl.size()) ; ++ind) {
    const SavedNodeOnDisk &savedNode = *savedNodeColl[ind];

    const DottedRuleOnDisk &prevDottedRule = savedNode.GetDottedRule();
    const OnDiskPt::PhraseNode &prevNode = prevDottedRule.GetLastNode();
    // Root rules start at the cell's start; others continue after their
    // previously covered span.
    size_t startPos = prevDottedRule.IsRoot() ? range.GetStartPos() : prevDottedRule.GetWordsRange().GetEndPos() + 1;

    // search for terminal symbol
    if (startPos == absEndPos) {
      OnDiskPt::Word *sourceWordBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_inputFactorsVec, sourceWordLabel.GetLabel());

      if (sourceWordBerkeleyDb != NULL) {
        const OnDiskPt::PhraseNode *node = prevNode.GetChild(*sourceWordBerkeleyDb, m_dbWrapper);
        if (node != NULL) {
          // TODO figure out why source word is needed from node, not from sentence
          // prob to do with factors or non-term
          //const Word &sourceWord = node->GetSourceWord();
          DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, sourceWordLabel, prevDottedRule);
          expandableDottedRuleList.Add(relEndPos+1, dottedRule);

          // cache for cleanup
          m_sourcePhraseNode.push_back(node);
        }

        delete sourceWordBerkeleyDb;
      }
    }

    // search for non-terminals
    size_t endPos, stackInd;
    if (startPos > absEndPos)
      continue;
    else if (startPos == range.GetStartPos() && range.GetEndPos() > range.GetStartPos()) {
      // start.
      endPos = absEndPos - 1;
      stackInd = relEndPos;
    } else {
      endPos = absEndPos;
      stackInd = relEndPos + 1;
    }

    // get target nonterminals in this span from chart
    const ChartCellLabelSet &chartNonTermSet = GetTargetLabelSet(startPos, endPos);

    //const Word &defaultSourceNonTerm = staticData.GetInputDefaultNonTerminal()
    //                                   ,&defaultTargetNonTerm = staticData.GetOutputDefaultNonTerminal();

    // go through each SOURCE lhs
    const NonTerminalSet &sourceLHSSet = GetParser().GetInputPath(startPos, endPos).GetNonTerminalSet();

    NonTerminalSet::const_iterator iterSourceLHS;
    for (iterSourceLHS = sourceLHSSet.begin(); iterSourceLHS != sourceLHSSet.end(); ++iterSourceLHS) {
      const Word &sourceLHS = *iterSourceLHS;

      OnDiskPt::Word *sourceLHSBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_inputFactorsVec, sourceLHS);

      if (sourceLHSBerkeleyDb == NULL) {
        // (delete of NULL is a harmless no-op, kept as-is)
        delete sourceLHSBerkeleyDb;
        continue; // vocab not in pt. node definitely won't be in there
      }

      const OnDiskPt::PhraseNode *sourceNode = prevNode.GetChild(*sourceLHSBerkeleyDb, m_dbWrapper);
      delete sourceLHSBerkeleyDb;

      if (sourceNode == NULL)
        continue; // didn't find source node

      // go through each TARGET lhs
      ChartCellLabelSet::const_iterator iterChartNonTerm;
      for (iterChartNonTerm = chartNonTermSet.begin(); iterChartNonTerm != chartNonTermSet.end(); ++iterChartNonTerm) {
        if (*iterChartNonTerm == NULL) {
          continue;
        }
        const ChartCellLabel &cellLabel = **iterChartNonTerm;

        // Span-limit the search: labelled (source-syntax) non-terminals and
        // default non-terminals use different maximum span widths.
        bool doSearch = true;
        if (m_dictionary.m_maxSpanDefault != NOT_FOUND) {
          // for Hieu's source syntax
          bool isSourceSyntaxNonTerm = sourceLHS != defaultSourceNonTerm;
          size_t nonTermNumWordsCovered = endPos - startPos + 1;

          doSearch = isSourceSyntaxNonTerm ?
                     nonTermNumWordsCovered <= m_dictionary.m_maxSpanLabelled :
                     nonTermNumWordsCovered <= m_dictionary.m_maxSpanDefault;
        }

        if (doSearch) {

          OnDiskPt::Word *chartNonTermBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_outputFactorsVec, cellLabel.GetLabel());

          if (chartNonTermBerkeleyDb == NULL)
            continue;

          const OnDiskPt::PhraseNode *node = sourceNode->GetChild(*chartNonTermBerkeleyDb, m_dbWrapper);
          delete chartNonTermBerkeleyDb;

          if (node == NULL)
            continue;

          // found matching entry
          //const Word &sourceWord = node->GetSourceWord();
          DottedRuleOnDisk *dottedRule = new DottedRuleOnDisk(*node, cellLabel, prevDottedRule);
          expandableDottedRuleList.Add(stackInd, dottedRule);

          m_sourcePhraseNode.push_back(node);
        }
      } // for (iterChartNonTerm

      delete sourceNode;

    } // for (iterLabelListf

    // return list of target phrases
    DottedRuleCollOnDisk &nodes = expandableDottedRuleList.Get(relEndPos + 1);

    // source LHS
    DottedRuleCollOnDisk::const_iterator iterDottedRuleColl;
    for (iterDottedRuleColl = nodes.begin(); iterDottedRuleColl != nodes.end(); ++iterDottedRuleColl) {
      // node of last source word
      const DottedRuleOnDisk &prevDottedRule = **iterDottedRuleColl;
      // Skip rules already emitted on a previous call; mark as done now.
      if (prevDottedRule.Done())
        continue;
      prevDottedRule.Done(true);

      const OnDiskPt::PhraseNode &prevNode = prevDottedRule.GetLastNode();

      //get node for each source LHS
      const NonTerminalSet &lhsSet = GetParser().GetInputPath(range.GetStartPos(), range.GetEndPos()).GetNonTerminalSet();
      NonTerminalSet::const_iterator iterLabelSet;
      for (iterLabelSet = lhsSet.begin(); iterLabelSet != lhsSet.end(); ++iterLabelSet) {
        const Word &sourceLHS = *iterLabelSet;

        OnDiskPt::Word *sourceLHSBerkeleyDb = m_dbWrapper.ConvertFromMoses(m_inputFactorsVec, sourceLHS);
        if (sourceLHSBerkeleyDb == NULL)
          continue;

        const TargetPhraseCollection *targetPhraseCollection = NULL;
        const OnDiskPt::PhraseNode *node = prevNode.GetChild(*sourceLHSBerkeleyDb, m_dbWrapper);
        if (node) {
          uint64_t tpCollFilePos = node->GetValue();
          // Cache converted collections by their file position.
          std::map<uint64_t, const TargetPhraseCollection*>::const_iterator iterCache = m_cache.find(tpCollFilePos);
          if (iterCache == m_cache.end()) {

            const OnDiskPt::TargetPhraseCollection *tpcollBerkeleyDb = node->GetTargetPhraseCollection(m_dictionary.GetTableLimit(), m_dbWrapper);

            std::vector<float> weightT = staticData.GetWeights(&m_dictionary);
            targetPhraseCollection
            = tpcollBerkeleyDb->ConvertToMoses(m_inputFactorsVec
                                               ,m_outputFactorsVec
                                               ,m_dictionary
                                               ,weightT
                                               ,m_dbWrapper.GetVocab()
                                               ,true);

            delete tpcollBerkeleyDb;
            m_cache[tpCollFilePos] = targetPhraseCollection;
          } else {
            // just get out of cache
            targetPhraseCollection = iterCache->second;
          }

          UTIL_THROW_IF2(targetPhraseCollection == NULL, "Error");
          if (!targetPhraseCollection->IsEmpty()) {
            AddCompletedRule(prevDottedRule, *targetPhraseCollection, range, outColl);
          }

        } // if (node)

        delete node;
        delete sourceLHSBerkeleyDb;
      }
    }
  } // for (size_t ind = 0; ind < savedNodeColl.size(); ++ind)

  //cerr << numDerivations << " ";
}
/**
 * Fire sparse phrase-pair features for this (source, target) phrase pair.
 *
 * Variants, each gated by a config flag:
 *  - m_simple:        "pp_src-words~tgt-words" indicator.
 *  - m_domainTrigger: the same pair name conditioned on topic id / topic
 *                     probabilities / per-document keyword triggers.
 *  - m_sourceContext: one feature per in-vocabulary sentence context word,
 *                     "pp_trigger~src-words~tgt-words".
 * Words within a phrase are joined by "," in the feature name.
 */
void PhrasePairFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedFutureScore) const
{
  const Phrase& source = inputPath.GetPhrase();
  if (m_simple) {
    // construct the plain pair feature name: pp_w1,w2,...~t1,t2,...
    ostringstream namestr;
    namestr << "pp_";
    namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      namestr << ",";
      namestr << sourceFactor->GetString();
    }
    namestr << "~";
    namestr << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
      const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
      namestr << ",";
      namestr << targetFactor->GetString();
    }
    scoreBreakdown.SparsePlusEquals(namestr.str(),1);
  }
  if (m_domainTrigger) {
    const Sentence& isnt = static_cast<const Sentence&>(input);
    const bool use_topicid = isnt.GetUseTopicId();
    const bool use_topicid_prob = isnt.GetUseTopicIdAndProb();

    // compute pair
    ostringstream pair;
    pair << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      pair << ",";
      pair << sourceFactor->GetString();
    }
    pair << "~";
    pair << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
      const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
      pair << ",";
      pair << targetFactor->GetString();
    }

    if (use_topicid || use_topicid_prob) {
      if(use_topicid) {
        // use topicid as trigger
        const long topicid = isnt.GetTopicId();
        stringstream feature;
        feature << "pp_";
        if (topicid == -1)
          feature << "unk";
        else
          feature << topicid;
        feature << "_";
        feature << pair.str();
        scoreBreakdown.SparsePlusEquals(feature.str(), 1);
      } else {
        // use topic probabilities
        const vector<string> &topicid_prob = *(isnt.GetTopicIdAndProb());
        if (atol(topicid_prob[0].c_str()) == -1) {
          stringstream feature;
          feature << "pp_unk_";
          feature << pair.str();
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        } else {
          // topicid_prob alternates id, probability — step by 2.
          for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
            stringstream feature;
            feature << "pp_";
            feature << topicid_prob[i];
            feature << "_";
            feature << pair.str();
            scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
          }
        }
      }
    } else {
      // range over domain trigger words
      const long docid = isnt.GetDocumentId();
      for (set<string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
        string sourceTrigger = *p;
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "_";
        namestr << pair.str();
        scoreBreakdown.SparsePlusEquals(namestr.str(),1);
      }
    }
  }
  if (m_sourceContext) {
    const Sentence& isnt = static_cast<const Sentence&>(input);

    // range over source words to get context
    for(size_t contextIndex = 0; contextIndex < isnt.GetSize(); contextIndex++ ) {
      StringPiece sourceTrigger = isnt.GetWord(contextIndex).GetFactor(m_sourceFactorId)->GetString();
      if (m_ignorePunctuation) {
        // check if trigger is punctuation
        char firstChar = sourceTrigger[0];
        CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
        if(charIterator != m_punctuationHash.end())
          continue;
      }
      bool sourceTriggerExists = false;
      if (!m_unrestricted)
        sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

      if (m_unrestricted || sourceTriggerExists) {
        // feature: pp_trigger~src-words~tgt-words
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "~";
        namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
        for (size_t i = 1; i < source.GetSize(); ++i) {
          const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
          namestr << ",";
          namestr << sourceFactor->GetString();
        }
        namestr << "~";
        namestr << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
        for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
          const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
          namestr << ",";
          namestr << targetFactor->GetString();
        }
        scoreBreakdown.SparsePlusEquals(namestr.str(),1);
      }
    }
  }
}