コード例 #1
0
/**
 * Build a TargetPhrase from the target-side word IDs of a suffix-array phrase.
 *
 * Every ID in phrase.words is resolved through m_trgVocab and appended to the
 * new phrase; a CHECK fires if an ID resolves to the vocabulary's OOV word.
 *
 * @param phrase        suffix-array phrase whose word IDs are looked up
 * @param sourcePhrase  source phrase recorded on the result
 * @return a TargetPhrase allocated with new and returned as a raw pointer
 *         (it is not released inside this function)
 */
TargetPhrase*
BilingualDynSuffixArray::
GetMosesFactorIDs(const SAPhrase& phrase, const Phrase& sourcePhrase) const
{
  TargetPhrase* result = new TargetPhrase();

  // Resolve each target word ID and append the word in order.
  const size_t numWords = phrase.words.size();
  size_t pos = 0;
  while (pos < numWords) {
    Word& word = m_trgVocab->GetWord(phrase.words[pos]);
    CHECK(word != m_trgVocab->GetkOOVWord());
    result->AddWord(word);
    ++pos;
  }

  result->SetSourcePhrase(sourcePhrase);
  // No scores are set here.
  return result;
}
コード例 #2
0
ファイル: LoaderStandard.cpp プロジェクト: Avmb/mosesdecoder
bool RuleTableLoaderStandard::Load(FormatType format
                                , const std::vector<FactorType> &input
                                , const std::vector<FactorType> &output
                                , const std::string &inFile
                                , const std::vector<float> &weight
                                , size_t /* tableLimit */
                                , const LMList &languageModels
                                , const WordPenaltyProducer* wpProducer
                                , RuleTableTrie &ruleTable)
{
  PrintUserTime(string("Start loading text SCFG phrase table. ") + (format==MosesFormat?"Moses ":"Hiero ") + " format");

  const StaticData &staticData = StaticData::Instance();
  const std::string& factorDelimiter = staticData.GetFactorDelimiter();

  string lineOrig;
  size_t count = 0;

  std::ostream *progress = NULL;
  IFVERBOSE(1) progress = &std::cerr;
  util::FilePiece in(inFile.c_str(), progress);

  // reused variables
  vector<float> scoreVector;
  StringPiece line;
  std::string hiero_before, hiero_after;

  while(true) {
    try {
      line = in.ReadLine();
    } catch (const util::EndOfFileException &e) { break; }

    if (format == HieroFormat) { // inefficiently reformat line
      hiero_before.assign(line.data(), line.size());
      ReformatHieroRule(hiero_before, hiero_after);
      line = hiero_after;
    }

    util::TokenIter<util::MultiCharacter> pipes(line, "|||");
    StringPiece sourcePhraseString(*pipes);
    StringPiece targetPhraseString(*++pipes);
    StringPiece scoreString(*++pipes);
    StringPiece alignString(*++pipes);
    // TODO(bhaddow) efficiently handle default instead of parsing this string every time.  
    StringPiece ruleCountString = ++pipes ? *pipes : StringPiece("1 1");
    
    if (++pipes) {
      stringstream strme;
      strme << "Syntax error at " << ruleTable.GetFilePath() << ":" << count;
      UserMessage::Add(strme.str());
      abort();
    }

    bool isLHSEmpty = (sourcePhraseString.find_first_not_of(" \t", 0) == string::npos);
    if (isLHSEmpty && !staticData.IsWordDeletionEnabled()) {
      TRACE_ERR( ruleTable.GetFilePath() << ":" << count << ": pt entry contains empty target, skipping\n");
      continue;
    }

    scoreVector.clear();
    for (util::TokenIter<util::AnyCharacter, true> s(scoreString, " \t"); s; ++s) {
      char *err_ind;
      scoreVector.push_back(strtod(s->data(), &err_ind));
      UTIL_THROW_IF(err_ind == s->data(), util::Exception, "Bad score " << *s << " on line " << count);
    }
    const size_t numScoreComponents = ruleTable.GetFeature()->GetNumScoreComponents();
    if (scoreVector.size() != numScoreComponents) {
      stringstream strme;
      strme << "Size of scoreVector != number (" << scoreVector.size() << "!="
            << numScoreComponents << ") of score components on line " << count;
      UserMessage::Add(strme.str());
      abort();
    }

    // parse source & find pt node

    // constituent labels
    Word sourceLHS, targetLHS;

    // source
    Phrase sourcePhrase( 0);
    sourcePhrase.CreateFromStringNewFormat(Input, input, sourcePhraseString, factorDelimiter, sourceLHS);

    // create target phrase obj
    TargetPhrase *targetPhrase = new TargetPhrase(Output);
    targetPhrase->CreateFromStringNewFormat(Output, output, targetPhraseString, factorDelimiter, targetLHS);
    targetPhrase->SetSourcePhrase(sourcePhrase);

    // rest of target phrase
    targetPhrase->SetAlignmentInfo(alignString, sourcePhrase);
    targetPhrase->SetTargetLHS(targetLHS);
    
    targetPhrase->SetRuleCount(ruleCountString, scoreVector[0]);
    //targetPhrase->SetDebugOutput(string("New Format pt ") + line);
    
    // component score, for n-best output
    std::transform(scoreVector.begin(),scoreVector.end(),scoreVector.begin(),TransformScore);
    std::transform(scoreVector.begin(),scoreVector.end(),scoreVector.begin(),FloorScore);

    targetPhrase->SetScoreChart(ruleTable.GetFeature(), scoreVector, weight, languageModels,wpProducer);

    TargetPhraseCollection &phraseColl = GetOrCreateTargetPhraseCollection(ruleTable, sourcePhrase, *targetPhrase, sourceLHS);
    phraseColl.Add(targetPhrase);

    count++;
  }

  // sort and prune each target phrase collection
  SortAndPrune(ruleTable);

  return true;
}
コード例 #3
0
ファイル: PhraseDecoder.cpp プロジェクト: Avmb/mosesdecoder
/**
 * Decode target phrases from an encoded bit stream and append them to tpv.
 *
 * The decoder is a small state machine that, per phrase, reads target
 * symbols until the stop symbol, then m_numScoreComponent scores, then
 * (if m_containsAlignmentInfo) alignment points until the alignment stop
 * symbol, and finally finishes the phrase. How a symbol is turned into words
 * depends on m_coding (REnc, PREnc, or plain symbols).
 *
 * @param tpv               vector the decoded phrases are appended to; if it
 *                          is non-empty on entry the result is not cached
 * @param encodedBitStream  stream the symbols, scores and alignments are read from
 * @param sourcePhrase      source phrase set on every decoded target phrase
 * @param topLevel          when false, PREnc decoding stops once tpv holds
 *                          m_maxRank phrases (if m_maxRank is non-zero)
 * @return tpv, or an empty TargetPhraseVectorPtr when one of the consistency
 *         checks on the decoded data fails
 */
TargetPhraseVectorPtr PhraseDecoder::DecodeCollection(
  TargetPhraseVectorPtr tpv, BitWrapper<> &encodedBitStream,
  const Phrase &sourcePhrase, bool topLevel)
{
  
  // True when tpv already held phrases on entry.
  bool extending = tpv->size();
  // Bits remaining in the stream; later handed to the decoding cache (PREnc).
  size_t bitsLeft = encodedBitStream.TellFromEnd();
    
  typedef std::pair<size_t, size_t> AlignPointSizeT;
  
  // REnc only: symbol ids of the source words, indexed by source position.
  std::vector<int> sourceWords;
  if(m_coding == REnc)
  {
    for(size_t i = 0; i < sourcePhrase.GetSize(); i++)
    {
      std::string sourceWord
        = sourcePhrase.GetWord(i).GetString(*m_input, false);
      unsigned idx = GetSourceSymbolId(sourceWord);
      sourceWords.push_back(idx);
    }
  }
  
  // Sentinels that terminate the symbol and alignment sections of a phrase.
  unsigned phraseStopSymbol = 0;
  AlignPoint alignStopSymbol(-1, -1);
  
  // Scores and alignment points collected for the phrase being decoded.
  std::vector<float> scores;
  std::set<AlignPointSizeT> alignment;
  
  enum DecodeState { New, Symbol, Score, Alignment, Add } state = New;
  
  size_t srcSize = sourcePhrase.GetSize();
  
  // Points at the phrase currently being filled (the last element of tpv).
  TargetPhrase* targetPhrase = NULL;
  while(encodedBitStream.TellFromEnd())
  {
     
    if(state == New)
    {
      // Append a fresh TargetPhrase to tpv and fill it in place
      tpv->push_back(TargetPhrase(Output));
      targetPhrase = &tpv->back();
      
      targetPhrase->SetSourcePhrase(sourcePhrase);
      alignment.clear();
      scores.clear();
        
      state = Symbol;
    }
    
    if(state == Symbol)
    {
      unsigned symbol = m_symbolTree->Read(encodedBitStream);      
      if(symbol == phraseStopSymbol)
      {
        // End of the target words; scores follow.
        state = Score;
      }
      else
      {
        if(m_coding == REnc)
        {
          std::string wordString;
          size_t type = GetREncType(symbol);
          
          if(type == 1)
          {
            // Type 1: the symbol decodes directly to a target symbol id.
            unsigned decodedSymbol = DecodeREncSymbol1(symbol);
            wordString = GetTargetSymbol(decodedSymbol);
          }
          else if (type == 2)
          {
            // Type 2: a rank plus an explicit source position; the word is
            // looked up via GetTranslation for the source word at srcPos.
            size_t rank = DecodeREncSymbol2Rank(symbol);
            size_t srcPos = DecodeREncSymbol2Position(symbol);
            
            // Source position outside the source phrase: give up.
            if(srcPos >= sourceWords.size())
              return TargetPhraseVectorPtr();  
            
            wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));
            if(m_phraseDictionary.m_useAlignmentInfo)
            {
              size_t trgPos = targetPhrase->GetSize();
              alignment.insert(AlignPoint(srcPos, trgPos));
            }
          }
          else if(type == 3)
          {
            // Type 3: a rank only; the source position is taken to be the
            // current target length, so source and target positions coincide.
            size_t rank = DecodeREncSymbol3(symbol);
            size_t srcPos = targetPhrase->GetSize();
            
            // Source position outside the source phrase: give up.
            if(srcPos >= sourceWords.size())
              return TargetPhraseVectorPtr();  
                            
            wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));   
            if(m_phraseDictionary.m_useAlignmentInfo)
            {
              size_t trgPos = srcPos;
              alignment.insert(AlignPoint(srcPos, trgPos));
            }
          }
          
          // Note: for any other type value wordString is still empty here.
          Word word;
          word.CreateFromString(Output, *m_output, wordString, false);
          targetPhrase->AddWord(word);
        }
        else if(m_coding == PREnc)
        {
          // if the symbol is just a word
          if(GetPREncType(symbol) == 1)
          {
            unsigned decodedSymbol = DecodePREncSymbol1(symbol);
     
            Word word;
            word.CreateFromString(Output, *m_output,
                                  GetTargetSymbol(decodedSymbol), false);
            targetPhrase->AddWord(word);
          }
          // if the symbol is a subphrase pointer
          else
          {
            int left = DecodePREncSymbol2Left(symbol);
            int right = DecodePREncSymbol2Right(symbol);
            unsigned rank = DecodePREncSymbol2Rank(symbol);
            
            // Source span of the sub-phrase: the start is offset from the
            // current target length, the end is counted from the source end.
            int srcStart = left + targetPhrase->GetSize();
            int srcEnd   = srcSize - right - 1;
            
            // false positive consistency check
            if(0 > srcStart || srcStart > srcEnd || unsigned(srcEnd) >= srcSize)
              return TargetPhraseVectorPtr();
            
            // false positive consistency check
            if(m_maxRank && rank > m_maxRank)
                return TargetPhraseVectorPtr();
            
            // set subphrase by default to itself
            TargetPhraseVectorPtr subTpv = tpv;
            
            // if range smaller than source phrase retrieve subphrase
            if(unsigned(srcEnd - srcStart + 1) != srcSize)
            {
              Phrase subPhrase = sourcePhrase.GetSubString(WordsRange(srcStart, srcEnd));
              subTpv = CreateTargetPhraseCollection(subPhrase, false);
            }
            
            // false positive consistency check
            if(subTpv != NULL && rank < subTpv->size())
            {
              // insert the subphrase into the main target phrase
              TargetPhrase& subTp = subTpv->at(rank);
              if(m_phraseDictionary.m_useAlignmentInfo)
              {
                // reconstruct the alignment data based on the alignment of the subphrase
                for(AlignmentInfo::const_iterator it = subTp.GetAlignmentInfo().begin();
                    it != subTp.GetAlignmentInfo().end(); it++)
                {
                  // Shift the sub-phrase alignment by its source start and
                  // by the current target length.
                  alignment.insert(AlignPointSizeT(srcStart + it->first,
                                                   targetPhrase->GetSize() + it->second));
                }
              }
              targetPhrase->Append(subTp);
            }
            else 
              return TargetPhraseVectorPtr();
          }
        }
        else
        {
            // Any other coding: the symbol is used as a target symbol id directly.
            Word word;
            word.CreateFromString(Output, *m_output,
                                  GetTargetSymbol(symbol), false);
            targetPhrase->AddWord(word);
        }
      }
    }
    else if(state == Score)
    {
      // With multiple score trees each score position has its own tree;
      // otherwise tree 0 is used for every score.
      size_t idx = m_multipleScoreTrees ? scores.size() : 0;
      float score = m_scoreTrees[idx]->Read(encodedBitStream);
      scores.push_back(score);
      
      // All scores for this phrase have been read.
      if(scores.size() == m_numScoreComponent)
      {
        targetPhrase->SetScore(m_feature, scores, ScoreComponentCollection() /*sparse*/,*m_weight, m_weightWP, *m_languageModels);
        
        if(m_containsAlignmentInfo)
          state = Alignment;
        else
          state = Add;
      }
    }
    else if(state == Alignment)
    {
      AlignPoint alignPoint = m_alignTree->Read(encodedBitStream);
      if(alignPoint == alignStopSymbol)
      {
        state = Add;
      }
      else
      {
        // The point is always read from the stream but only kept when
        // alignment information is in use.
        if(m_phraseDictionary.m_useAlignmentInfo)  
          alignment.insert(AlignPointSizeT(alignPoint));
      }
    }
    
    // Not part of the else-if chain above: a phrase that has just finished
    // its scores or alignment is completed in the same loop iteration.
    if(state == Add)
    {
      if(m_phraseDictionary.m_useAlignmentInfo)
        targetPhrase->SetAlignmentInfo(alignment);
      
      if(m_coding == PREnc)
      {
        // Record the stream position only while the number of decoded
        // phrases does not exceed m_maxRank (or no limit is set).
        if(!m_maxRank || tpv->size() <= m_maxRank)
          bitsLeft = encodedBitStream.TellFromEnd();
        
        // Nested (non top-level) decoding stops once m_maxRank phrases exist.
        if(!topLevel && m_maxRank && tpv->size() >= m_maxRank)
          break;
      }
      
      // 8 or fewer remaining bits end decoding (presumably padding -- TODO confirm).
      if(encodedBitStream.TellFromEnd() <= 8)
        break;
      
      state = New;
    }    
  }
  
  // Cache the PREnc result unless tpv already held phrases on entry; a
  // remainder of 8 bits or fewer is stored as 0.
  if(m_coding == PREnc && !extending)
  {
    bitsLeft = bitsLeft > 8 ? bitsLeft : 0;
    m_decodingCache.Cache(sourcePhrase, tpv, bitsLeft, m_maxRank);
  }
  
  return tpv;
}
コード例 #4
0
/**
 * Read the rule section of a compact rule table and add each rule to ruleTable.
 *
 * The first line holds the number of rules. Each following line holds three
 * IDs (source phrase, target phrase, alignment set), one score per score
 * component, and then a ':' token; anything after that is ignored.
 *
 * A line is validated before any of its tokens are used: it must contain
 * enough tokens, the token after the scores must start with ':', and the IDs
 * must index into the supplied tables. On a violation a message is added via
 * UserMessage and false is returned.
 *
 * @param reader          line reader positioned at the rule count line
 * @param vocab           vocabulary indexed by the target LHS IDs
 * @param sourcePhrases   source phrases indexed by source phrase ID
 * @param targetPhrases   target phrases indexed by target phrase ID
 * @param targetLhsIds    per target phrase ID, the vocab index of its LHS
 * @param alignmentSets   alignment sets indexed by alignment set ID
 * @param ruleTable       trie that receives the rules
 * @return true when every rule line was loaded, false on a malformed line
 */
bool RuleTableLoaderCompact::LoadRuleSection(
  LineReader &reader,
  const std::vector<Word> &vocab,
  const std::vector<Phrase> &sourcePhrases,
  const std::vector<Phrase> &targetPhrases,
  const std::vector<size_t> &targetLhsIds,
  const std::vector<const AlignmentInfo *> &alignmentSets,
  RuleTableTrie &ruleTable)
{
  // Read rule count.
  reader.ReadLine();
  const size_t ruleCount = std::atoi(reader.m_line.c_str());

  // Read rules and add to table.
  const size_t numScoreComponents = ruleTable.GetNumScoreComponents();
  // Index of the token that must follow the three IDs and the scores.
  const size_t separatorIndex = 3 + numScoreComponents;
  std::vector<float> scoreVector(numScoreComponents);
  std::vector<size_t> tokenPositions;
  for (size_t i = 0; i < ruleCount; ++i) {
    reader.ReadLine();

    tokenPositions.clear();
    FindTokens(tokenPositions, reader.m_line);

    // Validate the token layout before indexing into tokenPositions: the
    // line needs the IDs, all scores and the ':' token right after them.
    if (tokenPositions.size() <= separatorIndex ||
        reader.m_line[tokenPositions[separatorIndex]] != ':') {
      std::stringstream msg;
      msg << "Size of scoreVector != number ("
          << scoreVector.size() << "!=" << numScoreComponents
          << ") of score components on line " << reader.m_lineNum;
      UserMessage::Add(msg.str());
      return false;
    }

    const char *charLine = reader.m_line.c_str();

    // The first three tokens are IDs for the source phrase, target phrase,
    // and alignment set.
    const int sourcePhraseId = std::atoi(charLine+tokenPositions[0]);
    const int targetPhraseId = std::atoi(charLine+tokenPositions[1]);
    const int alignmentSetId = std::atoi(charLine+tokenPositions[2]);

    // Reject IDs that do not index into the tables read earlier.
    const bool idsInRange =
      sourcePhraseId >= 0 &&
      static_cast<size_t>(sourcePhraseId) < sourcePhrases.size() &&
      targetPhraseId >= 0 &&
      static_cast<size_t>(targetPhraseId) < targetPhrases.size() &&
      static_cast<size_t>(targetPhraseId) < targetLhsIds.size() &&
      targetLhsIds[targetPhraseId] < vocab.size() &&
      alignmentSetId >= 0 &&
      static_cast<size_t>(alignmentSetId) < alignmentSets.size();
    if (!idsInRange) {
      std::stringstream msg;
      msg << "Phrase, alignment set or LHS id out of range on line "
          << reader.m_lineNum;
      UserMessage::Add(msg.str());
      return false;
    }

    const Phrase &sourcePhrase = sourcePhrases[sourcePhraseId];
    const Phrase &targetPhrasePhrase = targetPhrases[targetPhraseId];
    Word sourceLHS("X"); // TODO not implemented for compact
    const AlignmentInfo *alignNonTerm = alignmentSets[alignmentSetId];

    // Then there should be one score for each score component.
    for (size_t j = 0; j < numScoreComponents; ++j) {
      float score = std::atof(charLine+tokenPositions[3+j]);
      scoreVector[j] = FloorScore(TransformScore(score));
    }

    // The remaining columns are currently ignored.

    // Allocate only after validation so that no error path above leaks it.
    // NOTE(review): presumably the target phrase takes ownership via
    // SetTargetLHS -- confirm against TargetPhrase.
    const Word *targetLhs = new Word(vocab[targetLhsIds[targetPhraseId]]);

    // Create and score target phrase.
    TargetPhrase *targetPhrase = new TargetPhrase(targetPhrasePhrase);
    targetPhrase->SetAlignNonTerm(alignNonTerm);
    targetPhrase->SetTargetLHS(targetLhs);
    targetPhrase->SetSourcePhrase(sourcePhrase);

    targetPhrase->Evaluate(sourcePhrase, ruleTable.GetFeaturesToApply());

    // Insert rule into table.
    TargetPhraseCollection &coll = GetOrCreateTargetPhraseCollection(
                                     ruleTable, sourcePhrase, *targetPhrase, &sourceLHS);
    coll.Add(targetPhrase);
  }

  return true;
}
コード例 #5
0
  /**
   * Build the interpolated target phrase collection for a source phrase.
   *
   * Every target phrase offered by any of m_dictionaries is combined into one
   * phrase whose scores (all but the last component) are the weighted sum, in
   * probability space, of the scores from each table that contains it, mapped
   * back to log space. The last component is set to 1 (assumed to be the
   * phrase penalty). The result is pruned to m_tableLimit.
   *
   * The previous result held in m_targetPhrases is deleted at the start of
   * each call, so the returned pointer is only valid until the next call.
   *
   * @param src  source phrase to look up in every component dictionary
   * @return the freshly built collection, stored in m_targetPhrases
   */
  const TargetPhraseCollection*
     PhraseDictionaryInterpolated::GetTargetPhraseCollection(const Phrase& src) const {

    delete m_targetPhrases;
    m_targetPhrases = new TargetPhraseCollection();

    // Gather the phrases of every table, both overall and per table.
    PhraseSet allPhrases;
    vector<PhraseSet> phrasesByTable(m_dictionaries.size());
    for (size_t i = 0; i < m_dictionaries.size(); ++i) {
      const TargetPhraseCollection* phrases = m_dictionaries[i]->GetTargetPhraseCollection(src);
      if (phrases) {
        for (TargetPhraseCollection::const_iterator j = phrases->begin();
          j != phrases->end(); ++j) {
          allPhrases.insert(*j);
          phrasesByTable[i].insert(*j);
        }
      }
    }

    ScoreComponentCollection sparseVector;
    for (PhraseSet::const_iterator i = allPhrases.begin(); i != allPhrases.end(); ++i) {
      // Copy only the Phrase part, so that no scores are carried over.
      TargetPhrase* combinedPhrase = new TargetPhrase(static_cast<const Phrase&>(**i));
      //combinedPhrase->ResetScore();
      //cerr << *combinedPhrase << " " << combinedPhrase->GetScoreBreakdown() << endl;
      combinedPhrase->SetSourcePhrase((*i)->GetSourcePhrase());
      combinedPhrase->SetAlignTerm(&((*i)->GetAlignTerm()));
      // Use the non-terminal alignment here (previously the terminal
      // alignment was passed a second time).
      combinedPhrase->SetAlignNonTerm(&((*i)->GetAlignNonTerm()));

      Scores combinedScores(GetFeature()->GetNumScoreComponents());
      for (size_t j = 0; j < phrasesByTable.size(); ++j) {
        PhraseSet::const_iterator tablePhrase = phrasesByTable[j].find(combinedPhrase);
        if (tablePhrase != phrasesByTable[j].end()) {
          Scores tableScores = (*tablePhrase)->GetScoreBreakdown()
            .GetScoresForProducer(GetFeature());
          //cerr << "Scores from " << j << " table: ";
          // "k + 1 < size" skips the last component without underflowing
          // when the score vector is empty.
          for (size_t k = 0; k + 1 < tableScores.size(); ++k) {
            //cerr << tableScores[k] << "(" << exp(tableScores[k]) << ") ";
            combinedScores[k] += m_weights[k][j] * exp(tableScores[k]);
            //cerr << m_weights[k][j] * exp(tableScores[k]) << " ";
          }
          //cerr << endl;
        }
      }

      //map back to log space
      //cerr << "Combined ";
      for (size_t k = 0; k + 1 < combinedScores.size(); ++k) {
        //cerr << combinedScores[k] << " ";
        combinedScores[k] = log(combinedScores[k]);
        //cerr << combinedScores[k] << " ";
      }
      //cerr << endl;
      if (!combinedScores.empty()) {
        combinedScores.back() = 1; //assume last is penalty
      }
      combinedPhrase->SetScore(
        GetFeature(),
        combinedScores,
        sparseVector,
        m_weightT,
        m_weightWP,
        *m_languageModels);
      //cerr << *combinedPhrase << " " << combinedPhrase->GetScoreBreakdown() <<  endl;
      m_targetPhrases->Add(combinedPhrase);
    }

    m_targetPhrases->Prune(true,m_tableLimit);


    return m_targetPhrases;
  }