Exemplo n.º 1
0
void CConParser::getPositiveFeatures( const CSentenceParsed &correct ) {
   static CTwoStringVector sentence;
   static CStateItem states[MAX_SENTENCE_SIZE*(1+UNARY_MOVES)+2];
   static CPackedScoreType<SCORE_TYPE, CAction::MAX> scores;
   static int current;
   static CAction action;

   states[0].clear();
   current = 0;
   UnparseSentence( &correct, &sentence ) ;
   m_lCache.clear();
   m_lWordLen.clear();
   for (unsigned i=0; i<sentence.size(); ++i) {
      m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[i].first , sentence[i].second) );
      m_lWordLen.push_back( getUTF8StringLength(sentence[i].first) );
   }

   while ( !states[current].IsTerminated() ) {
      states[current].StandardMove(correct, action);
//std::cout << action << std::endl;
      m_Context.load(states+current, m_lCache, m_lWordLen, true);
      getOrUpdateStackScore(static_cast<CWeight*>(m_weights), scores, states+current, action, 1, -1);
      states[current].Move(states+current+1, action);
      ++current;
   }
}
Exemplo n.º 2
0
void CConParser::work(const CTwoStringVector &sentence, const CSentenceParsed &correct, CCoNLLOutput *o_conll){
	const int length = sentence.size();
	static int tmp_i,tmp_j;
	static CAction correct_action;
	m_lCache.clear();
	m_lWordLen.clear();
	for (tmp_i=0; tmp_i<length; tmp_i++ ) {
		m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[tmp_i].first , sentence[tmp_i].second) );
		m_lWordLen.push_back( getUTF8StringLength(sentence[tmp_i].first) );
	}
	std::vector<CStateItem> p(MAX_SENTENCE_SIZE*(2+UNARY_MOVES)+2);
	CStateItem *correctState = &p[0];
	//getLabeledBrackets(correct, correctState->gold_lb);
	correctState->clear();
	correctState->words = (&m_lCache);
	tmp_i = 1;
	while(true){
		correctState->StandardMove(correct, correct_action);
		//std::cerr<<correct_action<<std::endl;
		correctState->Move(&p[tmp_i],correct_action);
		correctState = &p[tmp_i];
		tmp_i ++;
		if (correctState == 0 || correctState->IsTerminated()) break; // while
	}

   correctState->GenerateStanford(sentence, o_conll);
	p.clear();
}
Exemplo n.º 3
0
/*
 * train - one perceptron step on a gold-tagged sentence.
 *
 * Advances the training round, updates the tag dictionary from `correct`,
 * and tags the untagged words with the current model. If the output differs
 * from the gold sequence, the features of every predicted position are
 * subtracted, the features of every gold position are added, and the scores
 * are flagged as modified.
 *
 * Returns true iff an update was performed.
 */
bool TARGET_LANGUAGE::CTagger::train( const CTwoStringVector * correct ) {
   static CTwoStringVector tagged;
   static CStringVector sentence;

   UntagSentence( correct, &sentence );
   ++m_nTrainingRound;
   updateTagDict(correct);

   tag( &sentence , &tagged , 1 , NULL );
   const bool bMistake = ( tagged != *correct );
   if ( !bMistake )
      return false;

   // penalise the prediction, reward the gold standard
   const int nPredicted = static_cast<int>(tagged.size());
   for (int pos=0; pos<nPredicted; ++pos)
      updateLocalFeatureVector(eSubtract, &tagged, pos, m_nTrainingRound);
   const int nGold = static_cast<int>(correct->size());
   for (int pos=0; pos<nGold; ++pos)
      updateLocalFeatureVector(eAdd, correct, pos, m_nTrainingRound);

   m_bScoreModified = true;
   return true;
}
Exemplo n.º 4
0
/*
 * train - one perceptron step on a gold-tagged sentence.
 *
 * Advances the training round, caches the gold tag codes in m_CacheTags, and
 * tags the untagged words with the current model. If the output differs from
 * the gold sequence, each position's predicted features are subtracted and
 * its gold features added.
 *
 * When a tag dictionary is loaded, gold (word, tag) pairs that the
 * dictionary's possible-tag mask does not allow are reported with WARNING.
 * They are still used for the update (the old "skip dictionary-OOV examples"
 * filter was disabled).
 *
 * Returns true iff an update was performed.
 */
bool TARGET_LANGUAGE::CTagger::train( const CTwoStringVector * correct ) {
   static CTwoStringVector tagged;
   static CStringVector sentence;

   UntagSentence( correct, &sentence );

   ++m_nTrainingRound;

   m_CacheTags.clear();
   for (int i=0; i<static_cast<int>(correct->size()); ++i)
      m_CacheTags.push_back(CTag(correct->at(i).second).code());

   tag( &sentence , &tagged , 1 , NULL );
   if ( tagged != *correct ) {
      for (int i=0; i<static_cast<int>(tagged.size()); ++i) {
         if (m_TagDict) {
            // bit mask of the tags the dictionary allows for this word
            const unsigned long long possible_tags = getPossibleTagsForWord(correct->at(i).first);
            // unsigned shift: 1LL<<63 overflows a signed long long (UB);
            // NOTE(review): tag codes are assumed to be < 64
            const unsigned long long current_tag = (1ULL<<CTag(correct->at(i).second).code());
            if ( ( possible_tags & current_tag ) == 0 ) {
               WARNING("dictionary does not have the example word " << correct->at(i).first << " with tag " << correct->at(i).second);
            }
         }
         updateLocalFeatureVector(eSubtract, &tagged, i, m_nTrainingRound);
         updateLocalFeatureVector(eAdd, correct, i, m_nTrainingRound);
         m_bScoreModified = true;
      }
      return true;
   }

   return false;
}
Exemplo n.º 5
0
/*
 * work - agenda-based beam-search decoding shared by training and parsing.
 *
 * bTrain   - when true, `correct` is the gold tree. The gold state is advanced
 *            alongside the agenda, and scores are updated when it drops out
 *            (EARLY_UPDATE) or is not the best item at the end.
 * sentence - words with their POS tags.
 * retval   - receives up to min(generatorSize, nBest) trees when decoding
 *            finishes without a training update.
 * scores   - optional; receives the score of each output tree.
 *
 * With rules enabled, a training example whose gold tree contradicts the tag
 * rules (root / left head / right head / label) is skipped. The root word has
 * no head, so it is only checked by the root rule, never by the
 * head-direction rules.
 */
void CDepParser::work( const bool bTrain , const CTwoStringVector &sentence , CDependencyParse *retval , const CDependencyParse &correct , int nBest , SCORE_TYPE *scores ) {

#ifdef DEBUG
   clock_t total_start_time = clock();
#endif
   static int index;
   const int length = sentence.size() ; 

   const CStateItem *pGenerator ;
   static CStateItem pCandidate(&m_lCache) ;

   // used only for training
   static bool bCorrect ;  // used in learning for early update
   static bool bContradictsRules;
   static CStateItem correctState(&m_lCache) ;
   static CPackedScoreType<SCORE_TYPE, action::MAX> packed_scores;

   ASSERT(length<MAX_SENTENCE_SIZE, "The size of the sentence is larger than the system configuration.");

   TRACE("Initialising the decoding process...") ;
   // initialise word cache
   bContradictsRules = false;
   m_lCache.clear();
   for ( index=0; index<length; ++index ) {
      m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[index].first , sentence[index].second) );
      // filter out training examples that contradict the rules
      if (bTrain && m_weights->rules()) {
         // the root
         if ( correct[index].head == DEPENDENCY_LINK_NO_HEAD && canBeRoot(m_lCache[index].tag.code())==false) {
            TRACE("Rule contradiction: " << m_lCache[index].tag.code() << " can be root.");
            bContradictsRules = true;
         }
         // head left -- the root has no head; without this guard a negative
         // DEPENDENCY_LINK_NO_HEAD would compare as "left of" every index
         if ( correct[index].head != DEPENDENCY_LINK_NO_HEAD && correct[index].head < index && hasLeftHead(m_lCache[index].tag.code())==false) {
            TRACE("Rule contradiction: " << m_lCache[index].tag.code() << " has left head.");
            bContradictsRules = true;
         }
         // head right -- likewise never applies to the root
         if ( correct[index].head != DEPENDENCY_LINK_NO_HEAD && correct[index].head > index && hasRightHead(m_lCache[index].tag.code())==false) {
            TRACE("Rule contradiction: " << m_lCache[index].tag.code() << " has right head.");
            bContradictsRules = true;
         }
      }
   }

   // initialise agenda
   m_Agenda->clear();
   pCandidate.clear();                          // restore state using clean
   m_Agenda->pushCandidate(&pCandidate);           // and push it back
   m_Agenda->nextRound();                       // as the generator item
   if (bTrain) correctState.clear();

   // verifying supertags
   if (m_supertags) {
      ASSERT(m_supertags->getSentenceSize()==length, "Sentence size does not match supertags size");
   }

#ifdef LABELED
   unsigned long label;
   m_lCacheLabel.clear();
   if (bTrain) {
      for (index=0; index<length; ++index) {
         m_lCacheLabel.push_back(CDependencyLabel(correct[index].label));
         if (m_weights->rules() && !canAssignLabel(m_lCache, correct[index].head, index, m_lCacheLabel[index])) {
            TRACE("Rule contradiction: " << correct[index].label << " on link head " << m_lCache[correct[index].head].tag.code() << " dep " << m_lCache[index].tag.code());
            bContradictsRules = true;
         }
      }
   }
#endif

   // skip the training example if contradicts
   if (bTrain && m_weights->rules() && bContradictsRules) {
      std::cout << "Skipping training example because it contradicts rules..." <<std::endl;
      return;
   }

   TRACE("Decoding started"); 
   // loop with the next word to process in the sentence
   for (index=0; index<length*2; ++index) {
      
      if (bTrain) bCorrect = false ; 

      // pruning may have removed every generator
      if (m_Agenda->generatorSize() == 0) {
         WARNING("parsing failed"); 
         return;
      }

      pGenerator = m_Agenda->generatorStart();
      // iterate generators
      for (int j=0; j<m_Agenda->generatorSize(); ++j) {

         // for the state items that already contain all words
         m_Beam->clear();
         packed_scores.reset();
         getOrUpdateStackScore( pGenerator, packed_scores, action::NO_ACTION );
         if ( pGenerator->size() == length ) {
            assert( pGenerator->stacksize() != 0 );
            if ( pGenerator->stacksize()>1 ) {
#ifdef FRAGMENTED_TREE
               if (pGenerator->head(pGenerator->stacktop()) == DEPENDENCY_LINK_NO_HEAD)
                  poproot(pGenerator, packed_scores);
               else
#endif
               reduce(pGenerator, packed_scores) ;
            }
            else {
               poproot(pGenerator, packed_scores); 
            }
         }
         // for the state items that still need more words
         else {  
            if ( !pGenerator->afterreduce() ) { // there are many ways when there are many arcrighted items on the stack and the root need arcleft. force this.               
               if ( 
#ifndef FRAGMENTED_TREE
                    ( pGenerator->size() < length-1 || pGenerator->stackempty() ) && // keep only one global root
#endif
                    ( pGenerator->stackempty() || m_supertags == 0 || m_supertags->canShift( pGenerator->size() ) ) && // supertags
                    ( pGenerator->stackempty() || !m_weights->rules() || canBeRoot( m_lCache[pGenerator->size()].tag.code() ) || hasRightHead(m_lCache[pGenerator->size()].tag.code()) ) // rules
                  ) {
                  shift(pGenerator, packed_scores) ;
               }
            }
            if ( !pGenerator->stackempty() ) {
               if ( 
#ifndef FRAGMENTED_TREE
                    ( pGenerator->size() < length-1 || pGenerator->headstacksize() == 1 ) && // one root
#endif
                    ( m_supertags == 0 || m_supertags->canArcRight(pGenerator->stacktop(), pGenerator->size()) ) && // supertags conform to this action
                    ( !m_weights->rules() || hasLeftHead(m_lCache[pGenerator->size()].tag.code()) ) // rules
                  ) { 
                  arcright(pGenerator, packed_scores) ;
               }
            }
            if ( (!m_bCoNLL && !pGenerator->stackempty()) ||
                 (m_bCoNLL && pGenerator->stacksize()>1) // make sure that for conll the first item is not popped
               ) {
               if ( pGenerator->head( pGenerator->stacktop() ) != DEPENDENCY_LINK_NO_HEAD ) {
                  reduce(pGenerator, packed_scores) ;
               }
               else {
                  if ( (m_supertags == 0 || m_supertags->canArcLeft(pGenerator->size(), pGenerator->stacktop())) && // supertags
                       (!m_weights->rules() || hasRightHead(m_lCache[pGenerator->stacktop()].tag.code())) // rules
                     ) {
                     arcleft(pGenerator, packed_scores) ;
                  }
               }
            }
         }

         // insert item
         for (unsigned i=0; i<m_Beam->size(); ++i) {
            pCandidate = *pGenerator;
            pCandidate.score = m_Beam->item(i)->score;
            pCandidate.Move( m_Beam->item(i)->action );
            m_Agenda->pushCandidate(&pCandidate);
         }

         if (bTrain && *pGenerator == correctState) {
            bCorrect = true ;
         }
         pGenerator = m_Agenda->generatorNext() ;

      }
      // when we are doing training, we need to consider the standard move and update
      if (bTrain) {
#ifdef EARLY_UPDATE
         if (!bCorrect) {
            TRACE("Error at the "<<correctState.size()<<"th word; total is "<<correct.size())
            updateScoresForStates(m_Agenda->bestGenerator(), &correctState, 1, -1) ; 
#ifndef LOCAL_LEARNING
            return ;
#else
            m_Agenda->clearCandidates();
            m_Agenda->pushCandidate(&correctState);
#endif
         }
#endif

         if (bCorrect) {
#ifdef LABELED
            correctState.StandardMoveStep(correct, m_lCacheLabel);
#else
            correctState.StandardMoveStep(correct);
#endif
         }
#ifdef LOCAL_LEARNING
         ++m_nTrainingRound; // each training round is one transition-action
#endif
      } 
      
      m_Agenda->nextRound(); // move round
   }

   if (bTrain) {
      correctState.StandardFinish(); // pop the root that is left
      // then make sure that the correct item is stack top finally
      if ( *(m_Agenda->bestGenerator()) != correctState ) {
         TRACE("The best item is not the correct one")
         updateScoresForStates(m_Agenda->bestGenerator(), &correctState, 1, -1) ; 
         return ;
      }
   } 

   TRACE("Outputing sentence");
   m_Agenda->sortGenerators();
   for (int i=0; i<std::min(m_Agenda->generatorSize(), nBest); ++i) {
      pGenerator = m_Agenda->generator(i) ; 
      if (pGenerator) {
         pGenerator->GenerateTree( sentence , retval[i] ) ; 
         if (scores) scores[i] = pGenerator->score;
      }
   }
   TRACE("Done, the highest score is: " << m_Agenda->bestGenerator()->score ) ;
   TRACE("The total time spent: " << double(clock() - total_start_time)/CLOCKS_PER_SEC) ;
}
Exemplo n.º 6
0
/*
 * work - lattice beam-search decoding shared by training and parsing.
 *
 * Words equal to "-NONE-" are left out of the word cache. Each step expands
 * every state of the previous step with the actions returned by
 * m_rule.getActions, inserts the scored actions into a beam of capacity
 * AGENDA_SIZE, and materialises the beam as the next step's states. Decoding
 * stops when the best new state is terminated (with SCALE: when all are).
 *
 * bTrain  - when true, the gold state is followed through the lattice. The
 *           weights are updated against the best state when the gold state
 *           falls out of the beam (EARLY_UPDATE) or is not best at the end.
 * retval  - may be null; otherwise receives the best tree in retval[0].
 * nBest   - only 1 is supported.
 * scores  - optional; receives the best score in scores[0].
 */
void CConParser::work( const bool bTrain , const CTwoStringVector &sentence , CSentenceParsed *retval , const CSentenceParsed &correct , int nBest , SCORE_TYPE *scores ) {

   static CStateItem lattice[(MAX_SENTENCE_SIZE*(2+UNARY_MOVES)+2)*(AGENDA_SIZE+1)+1000000];
   static CStateItem *lattice_index[MAX_SENTENCE_SIZE*(2+UNARY_MOVES)+2+1000000];

#ifdef DEBUG
   clock_t total_start_time = clock();
#endif
   const int length = sentence.size() ; 

   const static CStateItem *pGenerator ;
   const static CStateItem *pBestGen;
   const static CStateItem *correctState ;
   static bool bCorrect ;  // used in learning for early update
   static int tmp_i, tmp_j;
   static CAction correct_action;
   static CScoredStateAction scored_correct_action;
   static bool correct_action_scored;
   static std::vector<CAction> actions; // actions to apply for a candidate
   static CAgendaSimple<CScoredStateAction> beam(AGENDA_SIZE);
   static CScoredStateAction scored_action; // used rank actions
   // Only the single best parse is written (retval[0] / scores[0]). This was
   // `nBest=1`, an assignment that made the check always pass.
   ASSERT(nBest==1, "currently only do 1 best parse");
   static unsigned index;
   static bool bSkipLast;
#ifdef SCALE
   bool bAllTerminated;
#endif

   static CPackedScoreType<SCORE_TYPE, CAction::MAX> packedscores;

   assert(length<MAX_SENTENCE_SIZE);

   TRACE("Initialising the decoding process ... ") ;
   // initialise word cache
   m_lCache.clear();
   m_lWordLen.clear();
   for ( tmp_i=0; tmp_i<length; tmp_i++ ) {
       if(sentence[tmp_i].first != "-NONE-")
       {
           m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[tmp_i].first , sentence[tmp_i].second) );
           m_lWordLen.push_back( getUTF8StringLength(sentence[tmp_i].first) );
       }
   }
   // initialise agenda
   lattice_index[0] = lattice;
   lattice_index[0]->clear();
   lattice_index[0]->m_lCache = &m_lCache;
   lattice_index[0]->m_lEmptyWords = m_lEmptyWords;
#ifdef TRAIN_LOSS
   lattice_index[0]->bTrain = m_bTrain;
   getLabeledBrackets(correct, lattice_index[0]->gold_lb);
   TRACE(lattice_index[0]->gold_lb << std::endl);
#endif

#ifndef EARLY_UPDATE
   if (bTrain) bSkipLast = false;
#endif
   lattice_index[1] = lattice+1;
   if (bTrain) { 
      correctState = lattice_index[0];
   }
   index=0;

   TRACE("Decoding start ... ") ;
   while (true) { // for each step

      ++index;
      lattice_index[index+1] = lattice_index[index];
         
      beam.clear();

      pBestGen = 0;

      if (bTrain) {
         bCorrect = false;
         correctState->StandardMove(correct, correct_action);
         correct_action_scored = false;
      }

      for (pGenerator=lattice_index[index-1]; pGenerator!=lattice_index[index]; ++pGenerator) { // for each generator

#ifndef EARLY_UPDATE
         if (bTrain && bSkipLast && pGenerator == lattice_index[index]-1) {
            getOrUpdateStackScore(static_cast<CWeight*>(m_weights), packedscores, pGenerator);
            scored_correct_action.load(correct_action, pGenerator, packedscores[correct_action.code()]);
            correct_action_scored = true;
            break;
         }
#endif

         // load context
         m_Context.load(pGenerator, m_lCache, m_lWordLen, false);
   
         // get actions
         m_rule.getActions(*pGenerator, actions);

         if (actions.size() > 0)
            getOrUpdateStackScore(static_cast<CWeight*>(m_weights), packedscores, pGenerator);
         
         for (tmp_j=0; tmp_j<actions.size(); ++tmp_j) {
            scored_action.load(actions[tmp_j], pGenerator, packedscores[actions[tmp_j].code()]);
            beam.insertItem(&scored_action);
            if (bTrain && pGenerator == correctState && actions[tmp_j] == correct_action) {
               scored_correct_action = scored_action;
               correct_action_scored = true;
            }
         }
   
      } // done iterating generator item

#ifdef SCALE
      bAllTerminated = true;
#endif
      // insertItems
      for (tmp_j=0; tmp_j<beam.size(); ++tmp_j) { // insert from
         pGenerator = beam.item(tmp_j)->item;
         pGenerator->Move(lattice_index[index+1], beam.item(tmp_j)->action);
         lattice_index[index+1]->score = beam.item(tmp_j)->score;
#ifdef SCALE
         if ( ! lattice_index[index+1]->IsTerminated() )
            bAllTerminated = false;
#endif

         if ( pBestGen == 0 || lattice_index[index+1]->score > pBestGen->score ) {
            pBestGen = lattice_index[index+1];
         }

         // update bestgen
         if (bTrain) {
            if ( pGenerator == correctState && beam.item(tmp_j)->action == correct_action ) {
               correctState = lattice_index[index+1];
               assert (correctState->unaryreduces()<=UNARY_MOVES) ; 
               bCorrect = true;
            }
         }
         ++lattice_index[index+1];
      }

#ifdef SCALE
      if (bAllTerminated)
         break; // while
#else
      if (pBestGen->IsTerminated())
         break; // while
#endif

      // update items if correct item jump out of the agenda
      if (bTrain) { 
         if (!bCorrect ) {
            // note that if bCorrect == true then the correct state has 
            // already been updated, and the new value is one of the new states
            // among the newly produced from lattice[index+1].
            correctState->Move(lattice_index[index+1], correct_action); 
            correctState = lattice_index[index+1];
            lattice_index[index+1]->score = scored_correct_action.score;
            ++lattice_index[index+1];
            assert(correct_action_scored); // scored_correct_act valid
#ifdef EARLY_UPDATE
            TRACE("Error at the "<<correctState->current_word<<"th word; total is "<<m_lCache.size())
            // update
#ifdef TRAIN_MULTI
            // same argument order as the final update below
            updateScoresForMultipleStates(lattice_index[index], lattice_index[index+1], pBestGen, correctState) ; 
#else
            // trace
            correctState->trace(&sentence);
            pBestGen->trace(&sentence);
//            updateScoresByLoss(pBestGen, correctState) ; 
            updateScoresForStates(pBestGen, correctState) ; 
#endif // TRAIN_MULTI
            return ;
#else // EARLY UDPATE
            bSkipLast = true;
#endif
         } // bCorrect
      }  // bTrain
   } // while

   if (bTrain) {
      // make sure that the correct item is stack top finally
      if ( pBestGen != correctState ) {
         if (!bCorrect) {
            correctState->Move(lattice_index[index+1], correct_action); 
            correctState = lattice_index[index+1];
            lattice_index[index+1]->score = scored_correct_action.score;
            assert(correct_action_scored); // scored_correct_act valid
         }
         TRACE("The best item is not the correct one")
#ifdef TRAIN_MULTI
         updateScoresForMultipleStates(lattice_index[index], lattice_index[index+1], pBestGen, correctState) ; 
#else // TRAIN_MULTI
         correctState->trace(&sentence);
         pBestGen->trace(&sentence);
//         updateScoresByLoss(pBestGen, correctState) ; 
         updateScoresForStates(pBestGen, correctState) ; 
#endif // TRAIN_MULTI
         return ;
      }
      else {
         TRACE("correct");
         correctState->trace(&sentence);
         pBestGen->trace(&sentence);
      }
   } 

   if (!retval) 
      return;

   TRACE("Outputing sentence");
   pBestGen->GenerateTree( sentence, retval[0] );
   if (scores) scores[0] = pBestGen->score;

   TRACE("Done, the highest score is: " << pBestGen->score ) ;
   TRACE("The total time spent: " << double(clock() - total_start_time)/CLOCKS_PER_SEC) ;
}
Exemplo n.º 7
0
int
CDepParser::work(const bool is_train,
                 const CTwoStringVector & sentence,
                 CDependencyParse * retval0, CDependencyParse * retval1,
                 const CDependencyParse & oracle_tree0, const CDependencyParse & oracle_tree1,
                 int nbest,
                 SCORE_TYPE *scores) {

#ifdef DEBUG
  clock_t total_start_time = clock();
#endif

  const int length = sentence.size();
  const int max_round = length * 4 + 1;
  const int max_lattice_size = (kAgendaSize + 1) * max_round;

  ASSERT(length < MAX_SENTENCE_SIZE,
         "The size of sentence is too long.");

  CStateItem * lattice = GetLattice(max_lattice_size);
  CStateItem * lattice_wrapper[max_lattice_size];
  CStateItem ** lattice_index[max_round];
  CStateItem * correct_state = lattice;



  for (int i = 0; i < max_lattice_size; ++ i) {
    lattice_wrapper[i] = lattice + i;
    lattice[i].len_ = length;
  }

  lattice[0].clear();
  correct_state = lattice;
  lattice_index[0] = lattice_wrapper;
  lattice_index[1] = lattice_index[0] + 1;

  static CPackedScoreType<SCORE_TYPE, action::kMax> packed_scores;


  TRACE("Initialising the decoding process ...");

  m_lCache.clear();
  for (int i = 0; i < length; ++ i) {
    m_lCache.push_back(CTaggedWord<CTag, TAG_SEPARATOR>(sentence[i].first,
                                                        sentence[i].second));
#ifdef LABELED
    if (is_train) {
      if (i == 0) { m_lCacheLabel0.clear();  m_lCacheLabel1.clear(); }
      m_lCacheLabel0.push_back(CDependencyLabel(oracle_tree0[i].label));
      m_lCacheLabel1.push_back(CDependencyLabel(oracle_tree1[i].label));
    }
#endif
  }

  int num_results = 0;
  int round = 0;
  bool is_correct; // used for training to specify correct state in lattice

  // loop with the next word to process in the sentence,
  // `round` represent the generators, and the condidates should be inserted
  // into the `round + 1`
  for (round = 1; round < max_round; ++ round) {
    if (lattice_index[round - 1] == lattice_index[round]) {
      // there is nothing in generators, the proning has cut all legel
      // generator. actually, in this kind of case, we should raise a
      // exception. however to achieve a parsing tree, an alternative
      // solution is go back to the previous round
      WARNING("Parsing Failed!");
      -- round;
      break;
    }

    int current_beam_size = 0;
    // loop over the generator states
    // std::cout << "round : " << round << std::endl;
    for (CStateItem ** q = lattice_index[round - 1];
        q != lattice_index[round];
        ++ q) {
      const CStateItem * generator = (*q);
      m_Beam->clear(); packed_scores.reset();

      GetOrUpdateStackScore(generator, packed_scores, action::kNoAction);


      Transit(generator, packed_scores);

      for (unsigned i = 0; i < m_Beam->size(); ++ i) {
        CStateItem candidate; candidate = (*generator);
        // generate candidate state according to the states in beam
        int curIndex = candidate.nextactionindex();
        candidate.Move(curIndex, m_Beam->item(i)->action);
        candidate.score = m_Beam->item(i)->score;
        candidate.previous_ = generator;
        current_beam_size += InsertIntoBeam(lattice_index[round],
                                            &candidate,
                                            current_beam_size,
                                            kAgendaSize);
      }
    }


    lattice_index[round + 1] = lattice_index[round] + current_beam_size;

    if (is_train) {
        CStateItem next_correct_state(*correct_state);
      unsigned goldaction = next_correct_state.StandardMoveStep(oracle_tree0, oracle_tree1
#ifdef LABELED
          , m_lCacheLabel0, m_lCacheLabel1
#endif // end for LABELED
          );

      //std::cout << *correct_state << std::endl;
      //std::cout << goldaction << std::endl;

      next_correct_state.previous_ = correct_state;
      is_correct = false;

      for (CStateItem ** q = lattice_index[round];
           q != lattice_index[round + 1];
          ++ q) {

        CStateItem * p = *q;
        if (next_correct_state.last_action_index == p->last_action_index
            && next_correct_state.last_action[next_correct_state.last_action_index] == p->last_action[p->last_action_index]
             && p->previous_ == correct_state) {
          correct_state = p;
          is_correct = true;
          break;
        }
      }



      //std::cout << *correct_state << std::endl;
      //std::cout << goldaction << std::endl;

#ifdef EARLY_UPDATE
      if (!is_correct || round == max_round-1) {
        int curIndex = next_correct_state.nextactionindex();
        TRACE("ERROR at the " << next_correct_state.size() << "th word for schema " << curIndex);
        if(curIndex == 0)
        {
            TRACE(" Total is " << oracle_tree0.size());
        }
        else
        {
            TRACE(" Total is " << oracle_tree1.size());
        }

        CStateItem * best_generator = (*lattice_index[round]);
        for (CStateItem ** q = lattice_index[round];
             q != lattice_index[round + 1];
              ++ q) {
          CStateItem * p = (*q);
          if (best_generator->score < p->score) {
            best_generator = p;
          }
        }
        UpdateScoresForStates(best_generator, &next_correct_state, 1, -1);
        return -1;
      }
#endif // end for EARLY_UPDATE

    }
  }

//   if (is_train) {
//      CStateItem * best_generator = (*lattice_index[round-1]);
//        for (CStateItem ** q = lattice_index[round-1]; q != lattice_index[round]; ++ q) {
//           CStateItem * p = (*q);
//          if (best_generator->score < p->score) {
//               best_generator = p;
//            }
//        }
//        if (best_generator != correct_state) {
//            UpdateScoresForStates(best_generator, correct_state, 1, -1);
//        }
//        return -1;
//    }
  //delete[] sequence_correct_state;


/*
  if (is_train) {
      //correct_state->StandardFinish(); // pop the root that is left
     // then make sure that the correct item is stack top finally
      CStateItem * best_generator = (*lattice_index[round-1]);
      for (CStateItem ** q = lattice_index[round-1];
           q != lattice_index[round ];
           ++ q) {
        CStateItem * p = (*q);
        if (best_generator->score < p->score) {
          best_generator = p;
        }
      }

     {
        //TRACE("The best item is not the correct one")
        UpdateScoresForStates(best_generator, correct_state, 1, -1) ;
     }
  }
*/
  if (!retval0 || !retval1) {
    return -1;
  }

  TRACE("Output sentence");
  std::sort(lattice_index[round - 1], lattice_index[round], StateHeapMore);
  num_results = lattice_index[round] - lattice_index[round - 1];

  for (int i = 0; i < std::min(num_results, nbest); ++ i) {
    assert( (*(lattice_index[round - 1] + i))->size() == m_lCache.size());
    (*(lattice_index[round - 1] + i))->GenerateTree(sentence, retval0[i], retval1[i]);
    if (scores) { scores[i] = (*(lattice_index[round - 1] + i))->score; }
  }
  TRACE("Done, total time spent: " << double(clock() - total_start_time) / CLOCKS_PER_SEC);
  return num_results;
}
Exemplo n.º 8
0
/*---------------------------------------------------------------
 *
 * work - the working process shared by training and parsing
 *
 * Returns: makes a new instance of CDependencyParse
 *
 *--------------------------------------------------------------*/
int
CDepParser::work(const bool is_train,
                 const CTwoStringVector & sentence,
                 CDependencyParse * retval,
                 const CDependencyParse & oracle_tree,
                 int nbest,
                 SCORE_TYPE *scores) {
#ifdef DEBUG
    clock_t total_start_time = clock();
#endif

    const int length = sentence.size();
    const int max_round = length * 2 + 1;
    const int max_lattice_size = (kAgendaSize + 1) * max_round;

    ASSERT(length < MAX_SENTENCE_SIZE,
           "The size of sentence is too long.");

    CStateItem * lattice = GetLattice(max_lattice_size);
    CStateItem * lattice_index[max_round];
    CStateItem * correct_state = lattice;

    for (int i = 0; i < max_lattice_size; ++ i) {
        lattice[i].len_ = length;
    }

    lattice[0].clear();
    correct_state = lattice;
    lattice_index[0] = lattice;
    lattice_index[1] = lattice_index[0] + 1;

    static CPackedScore packed_scores;
    TRACE("Initialising the decoding process ...");

    m_lCache.clear();
    for (int i = 0; i < length; ++ i) {
        m_lCache.push_back(CTaggedWord<CTag, TAG_SEPARATOR>(sentence[i].first,
                           sentence[i].second));
#ifdef LABELED
        if (is_train) {
            if (i == 0) {
                m_lCacheLabel.clear();
            }
            m_lCacheLabel.push_back(CDependencyLabel(oracle_tree[i].label));
        }
#endif
    }

    int num_results = 0;
    int round = 0;
    bool is_correct; // used for training to specify correct state in lattice

    // loop with the next word to process in the sentence, 'round' represent the
    // generators, and the condidates should be inserted into the 'round + 1'
    for (round = 1; round < max_round; ++ round) {
        if (lattice_index[round - 1] == lattice_index[round]) {
            // There is nothing in generators, the proning has cut all legel
            // generator. Actually, in this kind of case, we should raise a
            // exception. However to achieve a parsing tree, an alternative
            // solution is go back to the previous round
            WARNING("Parsing Failed!");
            -- round;
            break;
        }

        current_beam_size_ = 0;
        // loop over the generator states
        // std::cout << "round : " << round << std::endl;
        for (CStateItem * q = lattice_index[round - 1]; q != lattice_index[round];
                ++ q) {
            const CStateItem * generator = q;
            packed_scores.reset();
            GetOrUpdateStackScore(generator, packed_scores, action::kNoAction);
            Transit(generator, packed_scores);
        }

        for (unsigned i = 0; i < current_beam_size_; ++ i) {
            const CScoredTransition& transition = m_kBestTransitions[i];
            CStateItem* target = lattice_index[round]+ i;
            (*target) = (*transition.source);
            // generate candidate state according to the states in beam
            target->Move(transition.action);
            target->score = transition.score;
            target->previous_ = transition.source;
        }

        lattice_index[round + 1] = lattice_index[round] + current_beam_size_;

        if (is_train) {
            CStateItem next_correct_state(*correct_state);
            next_correct_state.StandardMoveStep(oracle_tree
#ifdef LABELED
                                                , m_lCacheLabel
#endif // end for LABELED
                                               );

            next_correct_state.previous_ = correct_state;
            is_correct = false;

            for (CStateItem *p = lattice_index[round]; p != lattice_index[round + 1];
                    ++ p) {
                if (next_correct_state.last_action == p->last_action
                        && p->previous_ == correct_state) {
                    correct_state = p;
                    is_correct = true;
                    break;
                }
            }

#ifdef EARLY_UPDATE
            if (!is_correct) {
                TRACE("ERROR at the " << next_correct_state.size() << "th word;"
                      << " Total is " << oracle_tree.size());

                CStateItem * best_generator = lattice_index[round];
                for (CStateItem * p = lattice_index[round]; p != lattice_index[round + 1];
                        ++ p) {
                    if (best_generator->score < p->score) {
                        best_generator = p;
                    }
                }
                UpdateScoresForStates(best_generator, &next_correct_state, 1, -1);
                return -1;
            }
#endif // end for EARLY_UPDATE
        }
    }

    if (is_train) {
        CStateItem * best_generator = lattice_index[round-1];

        for (CStateItem * p = lattice_index[round-1]; p != lattice_index[round]; ++ p) {
            if (best_generator->score < p->score) {
                best_generator = p;
            }
        }
        if (best_generator != correct_state) {
            UpdateScoresForStates(best_generator, correct_state, 1, -1);
        }
        return -1;
    }

    if (!retval) {
        return -1;
    }

    TRACE("Output sentence");
    std::sort(lattice_index[round - 1], lattice_index[round], StateMore);
    num_results = lattice_index[round] - lattice_index[round - 1];

    for (int i = 0; i < std::min(num_results, nbest); ++ i) {
        assert( (lattice_index[round - 1] + i)->size() == m_lCache.size());
        (lattice_index[round - 1] + i)->GenerateTree(sentence, retval[i]);
        if (scores) {
            scores[i] = (lattice_index[round - 1] + i)->score;
        }
    }
    TRACE("Done, total time spent: " << double(clock() - total_start_time) / CLOCKS_PER_SEC);
    return num_results;
}