bool PhraseDictionaryTransliteration::SatisfyBackoff(const InputPath &inputPath) const
{
  const Phrase &sourcePhrase = inputPath.GetPhrase();

  assert(m_container);
  const DecodeGraph *decodeGraph = m_container->GetContainer();
  size_t backoff = decodeGraph->GetBackoff();

  if (backoff == 0) {
	  // ie. don't backoff. Collect ALL translations
	  return true;
  }

  if (sourcePhrase.GetSize() > backoff) {
	  // source phrase too big
	  return false;
  }

  // lookup translation only if no other translations
  InputPath::TargetPhrases::const_iterator iter;
  for (iter = inputPath.GetTargetPhrases().begin(); iter != inputPath.GetTargetPhrases().end(); ++iter) {
  	const std::pair<const TargetPhraseCollection*, const void*> &temp = iter->second;
  	const TargetPhraseCollection *tpCollPrev = temp.first;

  	if (tpCollPrev && tpCollPrev->GetSize()) {
  		// already have translation from another pt. Don't create translations
  		return false;
  	}
  }

  return true;
}
/**
 * Best-effort removal of the scratch files used for one transliteration run:
 * the single input file and the (recursively removed) output directory.
 * Failures are deliberately ignored, as in the original cleanup code.
 */
static void RemoveTransliterationTempFiles(const std::string &inFile, const std::string &outDir)
{
  remove(inFile.c_str());

  const std::string rmCmd = "rm -rf " + outDir;
  const int ignored = system(rmCmd.c_str());
  (void) ignored;
}

/**
 * Attaches the transliterations of the input path's source phrase to the
 * input path.
 *
 * Results are cached in the collection returned by GetCache(), keyed by
 * hash_value(sourcePhrase). On a cache miss the phrase is written to a
 * temporary file, the external prepare-transliteration-phrase-table.pl
 * script is run through system(), its output directory is turned into
 * target phrases by CreateTargetPhrases(), and the result is cached.
 *
 * Throws (via UTIL_THROW_IF2) when a temporary name cannot be generated,
 * the input file cannot be written, or the script returns non-zero. The
 * temporary file and directory are removed on success and on script failure.
 *
 * NOTE(review): the cache key is the hash only, so two phrases with the same
 * hash value would share one cache entry.
 * NOTE(review): tmpnam() only generates a name; another process can create
 * the same path before it is used. mkstemp()/mkdtemp() would avoid that but
 * need headers not visible from this block.
 * NOTE(review): the command line is assembled by plain concatenation of
 * member strings; values containing shell metacharacters are not escaped.
 *
 * @param inputPath path to receive the target phrase collection
 */
void PhraseDictionaryTransliteration::GetTargetPhraseCollection(InputPath &inputPath) const
{
    const Phrase &sourcePhrase = inputPath.GetPhrase();
    const size_t hash = hash_value(sourcePhrase);

    CacheColl &cache = GetCache();

    std::map<size_t, std::pair<const TargetPhraseCollection*, clock_t> >::iterator cached;
    cached = cache.find(hash);

    if (cached != cache.end()) {
    	// already in cache
    	const TargetPhraseCollection *tpColl = cached->second.first;
    	inputPath.SetTargetPhrases(*this, tpColl, NULL);
    	return;
    }

    // TRANSLITERATE
    // Use caller-provided buffers: tmpnam(NULL) writes into one shared
    // static buffer and may return NULL when no name can be generated.
    char inFileBuf[L_tmpnam];
    char outDirBuf[L_tmpnam];
    const char *inFileName = tmpnam(inFileBuf);
    UTIL_THROW_IF2(inFileName == NULL, "Cannot generate temporary file name for transliteration input");
    const char *outDirName = tmpnam(outDirBuf);
    UTIL_THROW_IF2(outDirName == NULL, "Cannot generate temporary directory name for transliteration output");

    const string inFile(inFileName);
    const string outDir(outDirName);

    ofstream inStream(inFile.c_str());
    inStream << sourcePhrase.ToString() << endl;
    inStream.close();

    // Do not launch the script on an input file that was not written.
    const bool inputWritten = !inStream.fail();
    if (!inputWritten) {
    	remove(inFile.c_str());
    }
    UTIL_THROW_IF2(!inputWritten, "Cannot write transliteration input file");

    string cmd = m_scriptDir + "/Transliteration/prepare-transliteration-phrase-table.pl" +
    		" --transliteration-model-dir " + m_filePath +
    		" --moses-src-dir " + m_mosesDir +
    		" --external-bin-dir " + m_externalDir +
    		" --input-extension " + m_inputLang +
    		" --output-extension " + m_outputLang +
    		" --oov-file " + inFile +
    		" --out-dir " + outDir;

    const int ret = system(cmd.c_str());
    if (ret != 0) {
    	// Clean up before throwing so a failed run does not leave files behind.
    	RemoveTransliterationTempFiles(inFile, outDir);
    }
    UTIL_THROW_IF2(ret != 0, "Transliteration script error");

    TargetPhraseCollection *tpColl = new TargetPhraseCollection();
    vector<TargetPhrase*> targetPhrases = CreateTargetPhrases(sourcePhrase, outDir);
    vector<TargetPhrase*>::const_iterator tpIter;
    for (tpIter = targetPhrases.begin(); tpIter != targetPhrases.end(); ++tpIter) {
    	TargetPhrase *tp = *tpIter;
    	tpColl->Add(tp);
    }

    // The cache entry stores the collection together with the clock() value
    // taken at insertion.
    std::pair<const TargetPhraseCollection*, clock_t> value(tpColl, clock());
    cache[hash] = value;

    inputPath.SetTargetPhrases(*this, tpColl, NULL);

    // clean up temporary files
    RemoveTransliterationTempFiles(inFile, outDir);
}
// Example #3
// 0
void Manager::CreateInputPaths(const InputPath &prevPath, size_t pos)
{
  if (pos >= m_sentence.GetSize()) {
    return;
  }

  Phrase *phrase = new Phrase(prevPath.GetPhrase(), 1);
  phrase->SetLastWord(m_sentence.GetWord(pos));

  InputPath *path = new InputPath(&prevPath, phrase, pos);
  m_inputPathQueue.push_back(path);

  CreateInputPaths(*path, pos + 1);
}
/**
 * Looks up the target phrases for one input path in the on-disk phrase
 * table by extending the node reached by the previous path with the last
 * word of this path's phrase.
 *
 * The parent node is the node stored for this table on the previous path,
 * or the wrapper's root source node when there is no previous path.
 * Nothing is stored on the input path when SatisfyBackoff() rejects it or
 * when the parent node is NULL. Otherwise SetTargetPhrases() is called with
 * either the collection and node that were found, or with NULL/NULL when
 * the word cannot be converted or no child node exists.
 *
 * @param inputPath path to look up and to store the result on
 */
void PhraseDictionaryOnDisk::GetTargetPhraseCollectionBatch(InputPath &inputPath) const
{
  OnDiskPt::OnDiskWrapper &wrapper = const_cast<OnDiskPt::OnDiskWrapper&>(GetImplementation());
  const Phrase &phrase = inputPath.GetPhrase();
  const InputPath *prevInputPath = inputPath.GetPrevPath();

  const OnDiskPt::PhraseNode *parentNode = NULL;
  if (prevInputPath == NULL) {
    // No predecessor: this is a starting subphrase, so begin at the root.
    assert(phrase.GetSize() == 1);
    parentNode = &wrapper.GetRootSourceNode();
  } else {
    parentNode = static_cast<const OnDiskPt::PhraseNode*>(prevInputPath->GetPtNode(*this));
  }

  // Backoff is checked first; a missing parent node also ends the lookup.
  if (!SatisfyBackoff(inputPath) || parentNode == NULL) {
    return;
  }

  Word lastWord = phrase.GetWord(phrase.GetSize() - 1);
  lastWord.OnlyTheseFactors(m_inputFactors);
  OnDiskPt::Word *diskWord = wrapper.ConvertFromMoses(m_input, lastWord);

  if (diskWord == NULL) {
    // The word is unknown to this phrase table, so the path cannot be extended.
    inputPath.SetTargetPhrases(*this, NULL, NULL);
    return;
  }

  const OnDiskPt::PhraseNode *childNode = parentNode->GetChild(*diskWord, wrapper);
  if (childNode == NULL) {
    inputPath.SetTargetPhrases(*this, NULL, NULL);
  } else {
    const TargetPhraseCollection *found = GetTargetPhraseCollection(childNode);
    inputPath.SetTargetPhrases(*this, found, childNode);
  }

  // The converted word is released here on both outcomes.
  delete diskWord;
}
void WordTranslationFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedScores) const
{
  const Sentence& sentence = static_cast<const Sentence&>(input);
  const AlignmentInfo &alignment = targetPhrase.GetAlignTerm();

  // process aligned words
  for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {
    const Phrase& sourcePhrase = inputPath.GetPhrase();
    int sourceIndex = alignmentPoint->first;
    int targetIndex = alignmentPoint->second;
    Word ws = sourcePhrase.GetWord(sourceIndex);
    if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue;
    Word wt = targetPhrase.GetWord(targetIndex);
    if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue;
    StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString();
    StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString();
    if (m_ignorePunctuation) {
      // check if source or target are punctuation
      char firstChar = sourceWord[0];
      CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
      firstChar = targetWord[0];
      charIterator = m_punctuationHash.find( firstChar );
      if(charIterator != m_punctuationHash.end())
        continue;
    }

    if (!m_unrestricted) {
      if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end())
        sourceWord = "OTHER";
      if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end())
        targetWord = "OTHER";
    }

    if (m_simple) {
      // construct feature name
      util::StringStream featureName;
      featureName << m_description << "_";
      featureName << sourceWord;
      featureName << "~";
      featureName << targetWord;
      scoreBreakdown.SparsePlusEquals(featureName.str(), 1);
    }
    if (m_domainTrigger && !m_sourceContext) {
      const bool use_topicid = sentence.GetUseTopicId();
      const bool use_topicid_prob = sentence.GetUseTopicIdAndProb();
      if (use_topicid || use_topicid_prob) {
        if(use_topicid) {
          // use topicid as trigger
          const long topicid = sentence.GetTopicId();
          util::StringStream feature;
          feature << m_description << "_";
          if (topicid == -1)
            feature << "unk";
          else
            feature << topicid;

          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        } else {
          // use topic probabilities
          const vector<string> &topicid_prob = *(input.GetTopicIdAndProb());
          if (atol(topicid_prob[0].c_str()) == -1) {
            util::StringStream feature;
            feature << m_description << "_unk_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            scoreBreakdown.SparsePlusEquals(feature.str(), 1);
          } else {
            for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
              util::StringStream feature;
              feature << m_description << "_";
              feature << topicid_prob[i];
              feature << "_";
              feature << sourceWord;
              feature << "~";
              feature << targetWord;
              scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
            }
          }
        }
      } else {
        // range over domain trigger words (keywords)
        const long docid = input.GetDocumentId();
        for (boost::unordered_set<std::string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
          string sourceTrigger = *p;
          util::StringStream feature;
          feature << m_description << "_";
          feature << sourceTrigger;
          feature << "_";
          feature << sourceWord;
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_sourceContext) {
      size_t globalSourceIndex = inputPath.GetWordsRange().GetStartPos() + sourceIndex;
      if (!m_domainTrigger && globalSourceIndex == 0) {
        // add <s> trigger feature for source
        util::StringStream feature;
        feature << m_description << "_";
        feature << "<s>,";
        feature << sourceWord;
        feature << "~";
        feature << targetWord;
        scoreBreakdown.SparsePlusEquals(feature.str(), 1);
      }

      // range over source words to get context
      for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
        if (contextIndex == globalSourceIndex) continue;
        StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
        if (m_ignorePunctuation) {
          // check if trigger is punctuation
          char firstChar = sourceTrigger[0];
          CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
          if(charIterator != m_punctuationHash.end())
            continue;
        }

        const long docid = input.GetDocumentId();
        bool sourceTriggerExists = false;
        if (m_domainTrigger)
          sourceTriggerExists = FindStringPiece(m_vocabDomain[docid], sourceTrigger ) != m_vocabDomain[docid].end();
        else if (!m_unrestricted)
          sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

        if (m_domainTrigger) {
          if (sourceTriggerExists) {
            util::StringStream feature;
            feature << m_description << "_";
            feature << sourceTrigger;
            feature << "_";
            feature << sourceWord;
            feature << "~";
            feature << targetWord;
            scoreBreakdown.SparsePlusEquals(feature.str(), 1);
          }
        } else if (m_unrestricted || sourceTriggerExists) {
          util::StringStream feature;
          feature << m_description << "_";
          if (contextIndex < globalSourceIndex) {
            feature << sourceTrigger;
            feature << ",";
            feature << sourceWord;
          } else {
            feature << sourceWord;
            feature << ",";
            feature << sourceTrigger;
          }
          feature << "~";
          feature << targetWord;
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        }
      }
    }
    if (m_targetContext) {
      throw runtime_error("Can't use target words outside current translation option in a stateless feature");
      /*
      size_t globalTargetIndex = cur_hypo.GetCurrTargetWordsRange().GetStartPos() + targetIndex;
      if (globalTargetIndex == 0) {
      	// add <s> trigger feature for source
      	stringstream feature;
      	feature << "wt_";
      	feature << sourceWord;
      	feature << "~";
      	feature << "<s>,";
      	feature << targetWord;
      	accumulator->SparsePlusEquals(feature.str(), 1);
      }

      // range over target words (up to current position) to get context
      for(size_t contextIndex = 0; contextIndex < globalTargetIndex; contextIndex++ ) {
      	string targetTrigger = cur_hypo.GetWord(contextIndex).GetFactor(m_factorTypeTarget)->GetString();
      	if (m_ignorePunctuation) {
      		// check if trigger is punctuation
      		char firstChar = targetTrigger.at(0);
      		CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
      		if(charIterator != m_punctuationHash.end())
      			continue;
      	}

      	bool targetTriggerExists = false;
      	if (!m_unrestricted)
      		targetTriggerExists = m_vocabTarget.find( targetTrigger ) != m_vocabTarget.end();

      	if (m_unrestricted || targetTriggerExists) {
      		stringstream feature;
      		feature << "wt_";
      		feature << sourceWord;
      		feature << "~";
      		feature << targetTrigger;
      		feature << ",";
      		feature << targetWord;
      		accumulator->SparsePlusEquals(feature.str(), 1);
      	}
      }*/
    }
  }
}
void PhrasePairFeature::EvaluateWithSourceContext(const InputType &input
    , const InputPath &inputPath
    , const TargetPhrase &targetPhrase
    , const StackVec *stackVec
    , ScoreComponentCollection &scoreBreakdown
    , ScoreComponentCollection *estimatedFutureScore) const
{
  const Phrase& source = inputPath.GetPhrase();
  if (m_simple) {
    ostringstream namestr;
    namestr << "pp_";
    namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      namestr << ",";
      namestr << sourceFactor->GetString();
    }
    namestr << "~";
    namestr << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
      const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
      namestr << ",";
      namestr << targetFactor->GetString();
    }

    scoreBreakdown.SparsePlusEquals(namestr.str(),1);
  }
  if (m_domainTrigger) {
    const Sentence& isnt = static_cast<const Sentence&>(input);
    const bool use_topicid = isnt.GetUseTopicId();
    const bool use_topicid_prob = isnt.GetUseTopicIdAndProb();

    // compute pair
    ostringstream pair;
    pair << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
    for (size_t i = 1; i < source.GetSize(); ++i) {
      const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
      pair << ",";
      pair << sourceFactor->GetString();
    }
    pair << "~";
    pair << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
    for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
      const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
      pair << ",";
      pair << targetFactor->GetString();
    }

    if (use_topicid || use_topicid_prob) {
      if(use_topicid) {
        // use topicid as trigger
        const long topicid = isnt.GetTopicId();
        stringstream feature;
        feature << "pp_";
        if (topicid == -1)
          feature << "unk";
        else
          feature << topicid;

        feature << "_";
        feature << pair.str();
        scoreBreakdown.SparsePlusEquals(feature.str(), 1);
      } else {
        // use topic probabilities
        const vector<string> &topicid_prob = *(isnt.GetTopicIdAndProb());
        if (atol(topicid_prob[0].c_str()) == -1) {
          stringstream feature;
          feature << "pp_unk_";
          feature << pair.str();
          scoreBreakdown.SparsePlusEquals(feature.str(), 1);
        } else {
          for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
            stringstream feature;
            feature << "pp_";
            feature << topicid_prob[i];
            feature << "_";
            feature << pair.str();
            scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
          }
        }
      }
    } else {
      // range over domain trigger words
      const long docid = isnt.GetDocumentId();
      for (set<string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
        string sourceTrigger = *p;
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "_";
        namestr << pair.str();
        scoreBreakdown.SparsePlusEquals(namestr.str(),1);
      }
    }
  }
  if (m_sourceContext) {
    const Sentence& isnt = static_cast<const Sentence&>(input);

    // range over source words to get context
    for(size_t contextIndex = 0; contextIndex < isnt.GetSize(); contextIndex++ ) {
      StringPiece sourceTrigger = isnt.GetWord(contextIndex).GetFactor(m_sourceFactorId)->GetString();
      if (m_ignorePunctuation) {
        // check if trigger is punctuation
        char firstChar = sourceTrigger[0];
        CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
        if(charIterator != m_punctuationHash.end())
          continue;
      }

      bool sourceTriggerExists = false;
      if (!m_unrestricted)
        sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();

      if (m_unrestricted || sourceTriggerExists) {
        ostringstream namestr;
        namestr << "pp_";
        namestr << sourceTrigger;
        namestr << "~";
        namestr << source.GetWord(0).GetFactor(m_sourceFactorId)->GetString();
        for (size_t i = 1; i < source.GetSize(); ++i) {
          const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
          namestr << ",";
          namestr << sourceFactor->GetString();
        }
        namestr << "~";
        namestr << targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString();
        for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
          const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
          namestr << ",";
          namestr << targetFactor->GetString();
        }

        scoreBreakdown.SparsePlusEquals(namestr.str(),1);
      }
    }
  }
}