void CConParser::getPositiveFeatures( const CSentenceParsed &correct ) {
   static CTwoStringVector sentence;
   static CStateItem states[MAX_SENTENCE_SIZE*(1+UNARY_MOVES)+2];
   static CPackedScoreType<SCORE_TYPE, CAction::MAX> scores;
   static int current;
   static CAction action;

   states[0].clear();
   current = 0;
   UnparseSentence( &correct, &sentence ) ;

   m_lCache.clear();
   m_lWordLen.clear();
   for (unsigned i=0; i<sentence.size(); ++i) {
      m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[i].first , sentence[i].second) );
      m_lWordLen.push_back( getUTF8StringLength(sentence[i].first) );
   }

   while ( !states[current].IsTerminated() ) {
      states[current].StandardMove(correct, action);
      //std::cout << action << std::endl;
      m_Context.load(states+current, m_lCache, m_lWordLen, true);
      getOrUpdateStackScore(static_cast<CWeight*>(m_weights), scores, states+current, action, 1, -1);
      states[current].Move(states+current+1, action);
      ++current;
   }
}
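// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the parser): the idea behind
// getPositiveFeatures is to replay the gold-standard action sequence and
// accumulate the features fired along the gold derivation. The standalone
// sketch below shows the same pattern with a generic state and a hypothetical
// feature extractor; State, extractFeatures and collectPositiveFeatures are
// stand-ins for the real CStateItem / StandardMove / getOrUpdateStackScore
// machinery, not the actual API.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <map>
#include <string>
#include <vector>

namespace sketch_positive_features {

struct State {
   std::vector<int> actions;                       // actions applied so far
   bool terminated(std::size_t goldLen) const { return actions.size() == goldLen; }
   void move(int a) { actions.push_back(a); }
};

// hypothetical feature extractor: one string feature per (state size, action)
std::vector<std::string> extractFeatures(const State &s, int action) {
   return { "size=" + std::to_string(s.actions.size()) +
            ":action=" + std::to_string(action) };
}

// replay the gold actions, adding +1 to every feature fired on the gold path
std::map<std::string, double> collectPositiveFeatures(const std::vector<int> &goldActions) {
   std::map<std::string, double> counts;
   State state;
   while (!state.terminated(goldActions.size())) {
      int a = goldActions[state.actions.size()];   // the oracle ("standard") move
      for (const std::string &f : extractFeatures(state, a))
         counts[f] += 1.0;                         // analogous to the update above
      state.move(a);
   }
   return counts;
}

} // namespace sketch_positive_features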
void CConParser::work(const CTwoStringVector &sentence, const CSentenceParsed &correct, CCoNLLOutput *o_conll) {
   const int length = sentence.size();
   static int tmp_i, tmp_j;
   static CAction correct_action;

   m_lCache.clear();
   m_lWordLen.clear();
   for (tmp_i=0; tmp_i<length; tmp_i++) {
      m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[tmp_i].first , sentence[tmp_i].second) );
      m_lWordLen.push_back( getUTF8StringLength(sentence[tmp_i].first) );
   }

   std::vector<CStateItem> p(MAX_SENTENCE_SIZE*(2+UNARY_MOVES)+2);
   CStateItem *correctState = &p[0];
   //getLabeledBrackets(correct, correctState->gold_lb);
   correctState->clear();
   correctState->words = (&m_lCache);
   tmp_i = 1;
   while (true) {
      correctState->StandardMove(correct, correct_action);
      //std::cerr<<correct_action<<std::endl;
      correctState->Move(&p[tmp_i], correct_action);
      correctState = &p[tmp_i];
      tmp_i++;
      if (correctState == 0 || correctState->IsTerminated())
         break; // while
   }
   correctState->GenerateStanford(sentence, o_conll);
   p.clear();
}
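// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the parser): GenerateStanford ultimately
// writes one token per line in a CoNLL-style tab-separated format. The helper
// below is a hypothetical, minimal writer for CoNLL-X rows
// (ID FORM LEMMA CPOS POS FEATS HEAD DEPREL PHEAD PDEPREL), assuming heads are
// 0-based with -1 for the root; it is a rough stand-in for CCoNLLOutput, not
// the real class.
// ---------------------------------------------------------------------------
#include <ostream>
#include <string>
#include <vector>

namespace sketch_conll_writer {

struct Token {
   std::string form;
   std::string pos;
   int head;            // 0-based index of the head token, -1 for the root
   std::string deprel;
};

void writeCoNLL(std::ostream &os, const std::vector<Token> &tokens) {
   for (std::size_t i = 0; i < tokens.size(); ++i) {
      const Token &t = tokens[i];
      os << (i + 1) << '\t' << t.form << '\t' << '_' << '\t'
         << t.pos << '\t' << t.pos << '\t' << '_' << '\t'
         << (t.head + 1) << '\t' << t.deprel << '\t' << '_' << '\t' << '_' << '\n';
   }
   os << '\n';   // a blank line terminates the sentence
}

} // namespace sketch_conll_writer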
bool TARGET_LANGUAGE::CTagger::train( const CTwoStringVector * correct ) {
   static int i;
   static CTwoStringVector tagged;
   static CStringVector sentence;

   UntagSentence( correct, &sentence );
   ++m_nTrainingRound;
   updateTagDict(correct);

   tag( &sentence , &tagged , 1 , NULL );
   if ( tagged != *correct ) {
      for (i=0; i<tagged.size(); ++i)
         updateLocalFeatureVector(eSubtract, &tagged, i, m_nTrainingRound);
      for (i=0; i<correct->size(); ++i)
         updateLocalFeatureVector(eAdd, correct, i, m_nTrainingRound);
      m_bScoreModified = true;
      return true;
   }
   return false;
}
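// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the tagger): train() above is a standard
// structured-perceptron update. When the 1-best tagged sequence differs from
// the gold one, every local feature of the prediction is decremented and every
// local feature of the gold sequence is incremented. The sketch uses a plain
// map as the weight vector and a hypothetical feature template; it only
// mirrors the update rule, not the real feature set.
// ---------------------------------------------------------------------------
#include <map>
#include <string>
#include <utility>
#include <vector>

namespace sketch_perceptron_update {

typedef std::pair<std::string, std::string> WordTag;   // (word, tag)
typedef std::map<std::string, double> Weights;

// hypothetical local features: word/tag and previous-tag/tag
std::vector<std::string> localFeatures(const std::vector<WordTag> &sent, std::size_t i) {
   std::vector<std::string> fs;
   fs.push_back("w=" + sent[i].first + ":t=" + sent[i].second);
   fs.push_back("tprev=" + (i ? sent[i - 1].second : std::string("<s>")) + ":t=" + sent[i].second);
   return fs;
}

// returns true when an update was made (i.e. the prediction was wrong)
bool perceptronUpdate(Weights &w, const std::vector<WordTag> &predicted,
                      const std::vector<WordTag> &gold) {
   if (predicted == gold) return false;
   for (std::size_t i = 0; i < predicted.size(); ++i)
      for (const std::string &f : localFeatures(predicted, i)) w[f] -= 1.0;   // eSubtract
   for (std::size_t i = 0; i < gold.size(); ++i)
      for (const std::string &f : localFeatures(gold, i)) w[f] += 1.0;        // eAdd
   return true;
}

} // namespace sketch_perceptron_update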
bool TARGET_LANGUAGE::CTagger::train( const CTwoStringVector * correct ) {
   static int i;
   static CTwoStringVector tagged;
   static CStringVector sentence;
   static bool bDicOOV;
   static unsigned long long possible_tags; // possible tags for a word
   static unsigned long long current_tag;

   UntagSentence( correct, &sentence );
   ++m_nTrainingRound;

   m_CacheTags.clear();
   for (i=0; i<correct->size(); ++i)
      m_CacheTags.push_back(CTag(correct->at(i).second).code());

   tag( &sentence , &tagged , 1 , NULL );
   if ( tagged != *correct ) {
      for (i=0; i<tagged.size(); ++i) {
         bDicOOV = false;
         if (m_TagDict) {
            possible_tags = getPossibleTagsForWord(correct->at(i).first);
            current_tag = (1LL<<CTag(correct->at(i).second).code()) ;
            if ( ( possible_tags & current_tag ) == 0 ) {
               WARNING("dictionary does not have the example word " << correct->at(i).first << " with tag " << correct->at(i).second);
               bDicOOV = true;
            }
         }
         // if (!bDicOOV) {
         updateLocalFeatureVector(eSubtract, &tagged, i, m_nTrainingRound);
         updateLocalFeatureVector(eAdd, correct, i, m_nTrainingRound);
         m_bScoreModified = true;
         // }
      }
      return true;
   }
   return false;
}
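// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the tagger): the dictionary test above
// packs the set of tags seen with a word into a 64-bit mask and checks
// membership with a single bitwise AND. The helpers below show the same idiom
// with a hypothetical word -> mask map; they assume the tag set has at most 64
// tags, which is what the unsigned long long mask implies.
// ---------------------------------------------------------------------------
#include <string>
#include <unordered_map>

namespace sketch_tag_dictionary {

typedef std::unordered_map<std::string, unsigned long long> TagDict;

inline void addTag(TagDict &dict, const std::string &word, unsigned tagCode) {
   dict[word] |= (1ULL << tagCode);          // record that the word was seen with this tag
}

inline bool allowsTag(const TagDict &dict, const std::string &word, unsigned tagCode) {
   TagDict::const_iterator it = dict.find(word);
   if (it == dict.end()) return true;        // unknown word: no constraint
   return (it->second & (1ULL << tagCode)) != 0;
}

} // namespace sketch_tag_dictionary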
void CDepParser::work( const bool bTrain , const CTwoStringVector &sentence , CDependencyParse *retval , const CDependencyParse &correct , int nBest , SCORE_TYPE *scores ) {

#ifdef DEBUG
   clock_t total_start_time = clock();
#endif
   static int index;
   const int length = sentence.size() ;

   const CStateItem *pGenerator ;
   static CStateItem pCandidate(&m_lCache) ;

   // used only for training
   static bool bCorrect ;  // used in learning for early update
   static bool bContradictsRules;
   static CStateItem correctState(&m_lCache) ;
   static CPackedScoreType<SCORE_TYPE, action::MAX> packed_scores;

   ASSERT(length<MAX_SENTENCE_SIZE, "The size of the sentence is larger than the system configuration.");

   TRACE("Initialising the decoding process...") ;
   // initialise word cache
   bContradictsRules = false;
   m_lCache.clear();
   for ( index=0; index<length; ++index ) {
      m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[index].first , sentence[index].second) );
      // filter out training examples that contradict the rules
      if (bTrain && m_weights->rules()) {
         // the root
         if ( correct[index].head == DEPENDENCY_LINK_NO_HEAD && canBeRoot(m_lCache[index].tag.code())==false) {
            TRACE("Rule contradiction: " << m_lCache[index].tag.code() << " can be root.");
            bContradictsRules = true;
         }
         // head left
         if ( correct[index].head < index && hasLeftHead(m_lCache[index].tag.code())==false) {
            TRACE("Rule contradiction: " << m_lCache[index].tag.code() << " has left head.");
            bContradictsRules = true;
         }
         // head right
         if ( correct[index].head > index && hasRightHead(m_lCache[index].tag.code())==false) {
            TRACE("Rule contradiction: " << m_lCache[index].tag.code() << " has right head.");
            bContradictsRules = true;
         }
      }
   }

   // initialise agenda
   m_Agenda->clear();
   pCandidate.clear();                      // restore state using clear
   m_Agenda->pushCandidate(&pCandidate);    // and push it back
   m_Agenda->nextRound();                   // as the generator item
   if (bTrain) correctState.clear();

   // verifying supertags
   if (m_supertags) {
      ASSERT(m_supertags->getSentenceSize()==length, "Sentence size does not match supertags size");
   }

#ifdef LABELED
   unsigned long label;
   m_lCacheLabel.clear();
   if (bTrain) {
      for (index=0; index<length; ++index) {
         m_lCacheLabel.push_back(CDependencyLabel(correct[index].label));
         if (m_weights->rules() && !canAssignLabel(m_lCache, correct[index].head, index, m_lCacheLabel[index])) {
            TRACE("Rule contradiction: " << correct[index].label << " on link head " << m_lCache[correct[index].head].tag.code() << " dep " << m_lCache[index].tag.code());
            bContradictsRules = true;
         }
      }
   }
#endif

   // skip the training example if it contradicts the rules
   if (bTrain && m_weights->rules() && bContradictsRules) {
      std::cout << "Skipping training example because it contradicts rules..." <<std::endl;
      return;
   }

   TRACE("Decoding started");
   // loop with the next word to process in the sentence
   for (index=0; index<length*2; ++index) {

      if (bTrain) bCorrect = false ;

      // no generator can be found with pruning ???
      if (m_Agenda->generatorSize() == 0) {
         WARNING("parsing failed");
         return;
      }

      pGenerator = m_Agenda->generatorStart();
      // iterate generators
      for (int j=0; j<m_Agenda->generatorSize(); ++j) {

         m_Beam->clear();
         packed_scores.reset();
         getOrUpdateStackScore( pGenerator, packed_scores, action::NO_ACTION );

         // for the state items that already contain all words
         if ( pGenerator->size() == length ) {
            assert( pGenerator->stacksize() != 0 );
            if ( pGenerator->stacksize()>1 ) {
#ifdef FRAGMENTED_TREE
               if (pGenerator->head(pGenerator->stacktop()) == DEPENDENCY_LINK_NO_HEAD)
                  poproot(pGenerator, packed_scores);
               else
#endif
               reduce(pGenerator, packed_scores) ;
            }
            else {
               poproot(pGenerator, packed_scores);
            }
         }
         // for the state items that still need more words
         else {
            // there are many cases where many arc-righted items are on the stack
            // and the root needs an arcleft; force this.
            if ( !pGenerator->afterreduce() ) {
               if (
#ifndef FRAGMENTED_TREE
                     ( pGenerator->size() < length-1 || pGenerator->stackempty() ) && // keep only one global root
#endif
                     ( pGenerator->stackempty() || m_supertags == 0 || m_supertags->canShift( pGenerator->size() ) ) && // supertags
                     ( pGenerator->stackempty() || !m_weights->rules() || canBeRoot( m_lCache[pGenerator->size()].tag.code() ) || hasRightHead(m_lCache[pGenerator->size()].tag.code()) ) // rules
                  ) {
                  shift(pGenerator, packed_scores) ;
               }
            }
            if ( !pGenerator->stackempty() ) {
               if (
#ifndef FRAGMENTED_TREE
                     ( pGenerator->size() < length-1 || pGenerator->headstacksize() == 1 ) && // one root
#endif
                     ( m_supertags == 0 || m_supertags->canArcRight(pGenerator->stacktop(), pGenerator->size()) ) && // supertags conform to this action
                     ( !m_weights->rules() || hasLeftHead(m_lCache[pGenerator->size()].tag.code()) ) // rules
                  ) {
                  arcright(pGenerator, packed_scores) ;
               }
            }
            if ( (!m_bCoNLL && !pGenerator->stackempty()) ||
                 (m_bCoNLL && pGenerator->stacksize()>1) // make sure that for conll the first item is not popped
               ) {
               if ( pGenerator->head( pGenerator->stacktop() ) != DEPENDENCY_LINK_NO_HEAD ) {
                  reduce(pGenerator, packed_scores) ;
               }
               else {
                  if ( (m_supertags == 0 || m_supertags->canArcLeft(pGenerator->size(), pGenerator->stacktop())) && // supertags
                       (!m_weights->rules() || hasRightHead(m_lCache[pGenerator->stacktop()].tag.code())) // rules
                     ) {
                     arcleft(pGenerator, packed_scores) ;
                  }
               }
            }
         }

         // insert items
         for (unsigned i=0; i<m_Beam->size(); ++i) {
            pCandidate = *pGenerator;
            pCandidate.score = m_Beam->item(i)->score;
            pCandidate.Move( m_Beam->item(i)->action );
            m_Agenda->pushCandidate(&pCandidate);
         }

         if (bTrain && *pGenerator == correctState) {
            bCorrect = true ;
         }
         pGenerator = m_Agenda->generatorNext() ;
      }

      // when we are doing training, we need to consider the standard move and update
      if (bTrain) {
#ifdef EARLY_UPDATE
         if (!bCorrect) {
            TRACE("Error at the "<<correctState.size()<<"th word; total is "<<correct.size())
            updateScoresForStates(m_Agenda->bestGenerator(), &correctState, 1, -1) ;
#ifndef LOCAL_LEARNING
            return ;
#else
            m_Agenda->clearCandidates();
            m_Agenda->pushCandidate(&correctState);
#endif
         }
#endif

         if (bCorrect) {
#ifdef LABELED
            correctState.StandardMoveStep(correct, m_lCacheLabel);
#else
            correctState.StandardMoveStep(correct);
#endif
         }
#ifdef LOCAL_LEARNING
         ++m_nTrainingRound; // each training round is one transition action
#endif
      }

      m_Agenda->nextRound(); // move to the next round
   }

   if (bTrain) {
      correctState.StandardFinish(); // pop the root that is left
      // then make sure that the correct item finally ends up as the best one
      if ( *(m_Agenda->bestGenerator()) != correctState ) {
         TRACE("The best item is not the correct one")
         updateScoresForStates(m_Agenda->bestGenerator(), &correctState, 1, -1) ;
         return ;
      }
   }

   TRACE("Outputting sentence");
   m_Agenda->sortGenerators();
   for (int i=0; i<std::min(m_Agenda->generatorSize(), nBest); ++i) {
      pGenerator = m_Agenda->generator(i) ;
      if (pGenerator) {
         pGenerator->GenerateTree( sentence , retval[i] ) ;
         if (scores) scores[i] = pGenerator->score;
      }
   }
   TRACE("Done, the highest score is: " << m_Agenda->bestGenerator()->score ) ;
   TRACE("The total time spent: " << double(clock() - total_start_time)/CLOCKS_PER_SEC) ;
}
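// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the parser): the training branch of work()
// above is beam search with early update. The standalone toy below shows only
// that control flow: expand a beam of states, advance a gold state with its
// oracle action, and update the moment the gold prefix drops out of the beam.
// State, expand, goldStep and update are toy stand-ins; the real code
// interleaves this with the shift/arcleft/arcright/reduce scoring shown above.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cstddef>
#include <vector>

namespace sketch_early_update {

struct State {
   std::vector<int> actions;
   double score = 0.0;
};

// toy action set {0, 1}; a fixed "model score" per action stands in for the
// real feature-based scores
static const double kActionScore[2] = { 0.5, 1.0 };

std::vector<State> expand(const State &s) {
   std::vector<State> out;
   for (int a = 0; a < 2; ++a) {
      State next = s;
      next.actions.push_back(a);
      next.score += kActionScore[a];
      out.push_back(next);
   }
   return out;
}

int goldStep(std::size_t step) { return static_cast<int>(step % 2); }   // toy oracle

void update(const State &predicted, const State &gold) {
   // in the real parser this is the perceptron update: +1 on the features of
   // the gold prefix, -1 on the features of the predicted prefix
   (void)predicted;
   (void)gold;
}

void trainOneSentence(std::size_t numSteps, std::size_t beamSize) {
   std::vector<State> beam(1);   // start from the single empty state
   State gold;
   for (std::size_t step = 0; step < numSteps; ++step) {
      // expand every generator and keep the beamSize best candidates
      std::vector<State> candidates;
      for (const State &s : beam)
         for (const State &c : expand(s)) candidates.push_back(c);
      std::size_t keep = std::min(beamSize, candidates.size());
      std::partial_sort(candidates.begin(), candidates.begin() + keep, candidates.end(),
                        [](const State &a, const State &b) { return a.score > b.score; });
      candidates.resize(keep);
      // advance the gold state by its oracle action
      gold.actions.push_back(goldStep(step));
      gold.score += kActionScore[gold.actions.back()];
      // early update: stop as soon as the gold prefix drops out of the beam
      bool goldInBeam = false;
      for (const State &c : candidates)
         if (c.actions == gold.actions) { goldInBeam = true; break; }
      if (!goldInBeam) {
         update(candidates.front(), gold);   // best candidate vs. gold prefix
         return;
      }
      beam.swap(candidates);
   }
   // full pass: update only when the best final state is not the gold one
   if (beam.front().actions != gold.actions) update(beam.front(), gold);
}

} // namespace sketch_early_update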
void CConParser::work( const bool bTrain , const CTwoStringVector &sentence , CSentenceParsed *retval , const CSentenceParsed &correct , int nBest , SCORE_TYPE *scores ) {

   static CStateItem lattice[(MAX_SENTENCE_SIZE*(2+UNARY_MOVES)+2)*(AGENDA_SIZE+1)+1000000];
   static CStateItem *lattice_index[MAX_SENTENCE_SIZE*(2+UNARY_MOVES)+2+1000000];

#ifdef DEBUG
   clock_t total_start_time = clock();
#endif
   const int length = sentence.size() ;

   const static CStateItem *pGenerator ;
   const static CStateItem *pBestGen;
   const static CStateItem *correctState ;
   static bool bCorrect ;  // used in learning for early update
   static int tmp_i, tmp_j;
   static CAction correct_action;
   static CScoredStateAction scored_correct_action;
   static bool correct_action_scored;
   static std::vector<CAction> actions; // actions to apply for a candidate
   static CAgendaSimple<CScoredStateAction> beam(AGENDA_SIZE);
   static CScoredStateAction scored_action; // used to rank actions
   ASSERT(nBest==1, "currently only do 1 best parse");
   static unsigned index;
   static bool bSkipLast;
#ifdef SCALE
   bool bAllTerminated;
#endif

   static CPackedScoreType<SCORE_TYPE, CAction::MAX> packedscores;

   assert(length<MAX_SENTENCE_SIZE);

   TRACE("Initialising the decoding process ... ") ;
   // initialise word cache
   m_lCache.clear();
   m_lWordLen.clear();
   for ( tmp_i=0; tmp_i<length; tmp_i++ ) {
      if (sentence[tmp_i].first != "-NONE-") {
         m_lCache.push_back( CTaggedWord<CTag, TAG_SEPARATOR>(sentence[tmp_i].first , sentence[tmp_i].second) );
         m_lWordLen.push_back( getUTF8StringLength(sentence[tmp_i].first) );
      }
   }

   // initialise agenda
   lattice_index[0] = lattice;
   lattice_index[0]->clear();
   lattice_index[0]->m_lCache = &m_lCache;
   lattice_index[0]->m_lEmptyWords = m_lEmptyWords;
#ifdef TRAIN_LOSS
   lattice_index[0]->bTrain = m_bTrain;
   getLabeledBrackets(correct, lattice_index[0]->gold_lb);
   TRACE(lattice_index[0]->gold_lb << std::endl);
#endif
#ifndef EARLY_UPDATE
   if (bTrain) bSkipLast = false;
#endif
   lattice_index[1] = lattice+1;
   if (bTrain) {
      correctState = lattice_index[0];
   }
   index=0;

   TRACE("Decoding start ... ") ;
   while (true) { // for each step

      ++index;
      lattice_index[index+1] = lattice_index[index];
      beam.clear();
      pBestGen = 0;

      if (bTrain) {
         bCorrect = false;
         correctState->StandardMove(correct, correct_action);
         correct_action_scored = false;
      }

      for (pGenerator=lattice_index[index-1]; pGenerator!=lattice_index[index]; ++pGenerator) { // for each generator

#ifndef EARLY_UPDATE
         if (bTrain && bSkipLast && pGenerator == lattice_index[index]-1) {
            getOrUpdateStackScore(static_cast<CWeight*>(m_weights), packedscores, pGenerator);
            scored_correct_action.load(correct_action, pGenerator, packedscores[correct_action.code()]);
            correct_action_scored = true;
            break;
         }
#endif
         // load context
         m_Context.load(pGenerator, m_lCache, m_lWordLen, false);

         // get actions
         m_rule.getActions(*pGenerator, actions);
         if (actions.size() > 0)
            getOrUpdateStackScore(static_cast<CWeight*>(m_weights), packedscores, pGenerator);
         for (tmp_j=0; tmp_j<actions.size(); ++tmp_j) {
            scored_action.load(actions[tmp_j], pGenerator, packedscores[actions[tmp_j].code()]);
            beam.insertItem(&scored_action);
            if (bTrain && pGenerator == correctState && actions[tmp_j] == correct_action) {
               scored_correct_action = scored_action;
               correct_action_scored = true;
            }
         }

      } // done iterating generator items

#ifdef SCALE
      bAllTerminated = true;
#endif
      // insert items
      for (tmp_j=0; tmp_j<beam.size(); ++tmp_j) { // insert from the beam
         pGenerator = beam.item(tmp_j)->item;
         pGenerator->Move(lattice_index[index+1], beam.item(tmp_j)->action);
         lattice_index[index+1]->score = beam.item(tmp_j)->score;
#ifdef SCALE
         if ( ! lattice_index[index+1]->IsTerminated() ) bAllTerminated = false;
#endif
         if ( pBestGen == 0 || lattice_index[index+1]->score > pBestGen->score ) { // update bestgen
            pBestGen = lattice_index[index+1];
         }
         if (bTrain) {
            if ( pGenerator == correctState && beam.item(tmp_j)->action == correct_action ) {
               correctState = lattice_index[index+1];
               assert (correctState->unaryreduces()<=UNARY_MOVES) ;
               bCorrect = true;
            }
         }
         ++lattice_index[index+1];
      }

#ifdef SCALE
      if (bAllTerminated) break; // while
#else
      if (pBestGen->IsTerminated()) break; // while
#endif

      // update items if the correct item jumps out of the agenda
      if (bTrain) {
         if (!bCorrect) {
            // note that if bCorrect == true then the correct state has
            // already been updated, and the new value is one of the new states
            // among those newly produced in lattice[index+1].
            correctState->Move(lattice_index[index+1], correct_action);
            correctState = lattice_index[index+1];
            lattice_index[index+1]->score = scored_correct_action.score;
            ++lattice_index[index+1];
            assert(correct_action_scored); // scored_correct_action valid
#ifdef EARLY_UPDATE
            //if (!bCorrect ) {
            TRACE("Error at the "<<correctState->current_word<<"th word; total is "<<m_lCache.size())
            // update
#ifdef TRAIN_MULTI
            updateScoresForMultipleStates(lattice_index[index], lattice_index[index+1], candidate_outout, correctState) ;
#else
            // trace
            correctState->trace(&sentence);
            pBestGen->trace(&sentence);
            //updateScoresByLoss(pBestGen, correctState) ;
            updateScoresForStates(pBestGen, correctState) ;
#endif // TRAIN_MULTI
            return ;
            //} // bCorrect
#else // EARLY_UPDATE
            bSkipLast = true;
#endif
         } // bCorrect
      } // bTrain

   } // while

   if (bTrain) {
      // make sure that the correct item finally ends up as the best one
      if ( pBestGen != correctState ) {
         if (!bCorrect) {
            correctState->Move(lattice_index[index+1], correct_action);
            correctState = lattice_index[index+1];
            lattice_index[index+1]->score = scored_correct_action.score;
            assert(correct_action_scored); // scored_correct_action valid
         }
         TRACE("The best item is not the correct one")
#ifdef TRAIN_MULTI
         updateScoresForMultipleStates(lattice_index[index], lattice_index[index+1], pBestGen, correctState) ;
#else // TRAIN_MULTI
         correctState->trace(&sentence);
         pBestGen->trace(&sentence);
         //updateScoresByLoss(pBestGen, correctState) ;
         updateScoresForStates(pBestGen, correctState) ;
#endif // TRAIN_MULTI
         return ;
      }
      else {
         TRACE("correct");
         correctState->trace(&sentence);
         pBestGen->trace(&sentence);
      }
   }

   if (!retval)
      return;

   TRACE("Outputting sentence");
   pBestGen->GenerateTree( sentence, retval[0] );
   if (scores) scores[0] = pBestGen->score;

   TRACE("Done, the highest score is: " << pBestGen->score ) ;
   TRACE("The total time spent: " << double(clock() - total_start_time)/CLOCKS_PER_SEC) ;
}
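// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the parser): work() above stores all states
// in one contiguous lattice array and uses lattice_index[i] as the boundary
// pointer of step i, so the generators of a step live in
// [lattice_index[i-1], lattice_index[i]) and new candidates are appended by
// bumping lattice_index[i+1]. The toy below reproduces that bookkeeping with
// plain ints standing in for states.
// ---------------------------------------------------------------------------
#include <cassert>
#include <vector>

namespace sketch_lattice_rounds {

void demoLatticeRounds() {
   std::vector<int> lattice(64);
   std::vector<int*> index(8);
   index[0] = lattice.data();      // round 0 starts here
   *index[0] = 0;                  // the single initial state
   index[1] = index[0] + 1;

   for (int round = 1; round <= 3; ++round) {
      index[round + 1] = index[round];                 // nothing inserted yet
      for (int *gen = index[round - 1]; gen != index[round]; ++gen) {
         // each generator produces two successors in this toy example
         for (int k = 0; k < 2; ++k) {
            *index[round + 1] = *gen * 2 + k;          // "Move" into the next free slot
            ++index[round + 1];                        // grow the next round
         }
      }
      assert(index[round + 1] - index[round] == (index[round] - index[round - 1]) * 2);
   }
}

} // namespace sketch_lattice_rounds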
int CDepParser::work(const bool is_train,
                     const CTwoStringVector & sentence,
                     CDependencyParse * retval0,
                     CDependencyParse * retval1,
                     const CDependencyParse & oracle_tree0,
                     const CDependencyParse & oracle_tree1,
                     int nbest,
                     SCORE_TYPE *scores) {
#ifdef DEBUG
  clock_t total_start_time = clock();
#endif
  const int length = sentence.size();
  const int max_round = length * 4 + 1;
  const int max_lattice_size = (kAgendaSize + 1) * max_round;

  ASSERT(length < MAX_SENTENCE_SIZE, "The sentence is too long.");

  CStateItem * lattice = GetLattice(max_lattice_size);
  CStateItem * lattice_wrapper[max_lattice_size];
  CStateItem ** lattice_index[max_round];
  CStateItem * correct_state = lattice;

  for (int i = 0; i < max_lattice_size; ++ i) {
    lattice_wrapper[i] = lattice + i;
    lattice[i].len_ = length;
  }

  lattice[0].clear();
  correct_state = lattice;
  lattice_index[0] = lattice_wrapper;
  lattice_index[1] = lattice_index[0] + 1;

  static CPackedScoreType<SCORE_TYPE, action::kMax> packed_scores;

  TRACE("Initialising the decoding process ...");

  m_lCache.clear();
  for (int i = 0; i < length; ++ i) {
    m_lCache.push_back(CTaggedWord<CTag, TAG_SEPARATOR>(sentence[i].first, sentence[i].second));
#ifdef LABELED
    if (is_train) {
      if (i == 0) {
        m_lCacheLabel0.clear();
        m_lCacheLabel1.clear();
      }
      m_lCacheLabel0.push_back(CDependencyLabel(oracle_tree0[i].label));
      m_lCacheLabel1.push_back(CDependencyLabel(oracle_tree1[i].label));
    }
#endif
  }

  int num_results = 0;
  int round = 0;
  bool is_correct;  // used for training to identify the correct state in the lattice

  // loop with the next word to process in the sentence; `round` indexes the
  // generators, and the candidates are inserted into `round + 1`
  for (round = 1; round < max_round; ++ round) {
    if (lattice_index[round - 1] == lattice_index[round]) {
      // There is nothing in the generators: pruning has cut every legal
      // generator. In this kind of case we should really raise an exception;
      // however, to still produce a parse tree, an alternative solution is to
      // go back to the previous round.
      WARNING("Parsing Failed!");
      -- round;
      break;
    }

    int current_beam_size = 0;
    // loop over the generator states
    // std::cout << "round : " << round << std::endl;
    for (CStateItem ** q = lattice_index[round - 1]; q != lattice_index[round]; ++ q) {
      const CStateItem * generator = (*q);
      m_Beam->clear();
      packed_scores.reset();
      GetOrUpdateStackScore(generator, packed_scores, action::kNoAction);
      Transit(generator, packed_scores);

      for (unsigned i = 0; i < m_Beam->size(); ++ i) {
        CStateItem candidate;
        candidate = (*generator);
        // generate the candidate state according to the items in the beam
        int curIndex = candidate.nextactionindex();
        candidate.Move(curIndex, m_Beam->item(i)->action);
        candidate.score = m_Beam->item(i)->score;
        candidate.previous_ = generator;
        current_beam_size += InsertIntoBeam(lattice_index[round], &candidate, current_beam_size, kAgendaSize);
      }
    }
    lattice_index[round + 1] = lattice_index[round] + current_beam_size;

    if (is_train) {
      CStateItem next_correct_state(*correct_state);
      unsigned goldaction = next_correct_state.StandardMoveStep(oracle_tree0, oracle_tree1
#ifdef LABELED
          , m_lCacheLabel0, m_lCacheLabel1
#endif // end for LABELED
          );
      //std::cout << *correct_state << std::endl;
      //std::cout << goldaction << std::endl;
      next_correct_state.previous_ = correct_state;
      is_correct = false;

      for (CStateItem ** q = lattice_index[round]; q != lattice_index[round + 1]; ++ q) {
        CStateItem * p = *q;
        if (next_correct_state.last_action_index == p->last_action_index
            && next_correct_state.last_action[next_correct_state.last_action_index] == p->last_action[p->last_action_index]
            && p->previous_ == correct_state) {
          correct_state = p;
          is_correct = true;
          break;
        }
      }
      //std::cout << *correct_state << std::endl;
      //std::cout << goldaction << std::endl;

#ifdef EARLY_UPDATE
      if (!is_correct || round == max_round - 1) {
        int curIndex = next_correct_state.nextactionindex();
        TRACE("ERROR at the " << next_correct_state.size() << "th word for schema " << curIndex);
        if (curIndex == 0) {
          TRACE(" Total is " << oracle_tree0.size());
        } else {
          TRACE(" Total is " << oracle_tree1.size());
        }

        CStateItem * best_generator = (*lattice_index[round]);
        for (CStateItem ** q = lattice_index[round]; q != lattice_index[round + 1]; ++ q) {
          CStateItem * p = (*q);
          if (best_generator->score < p->score) {
            best_generator = p;
          }
        }
        UpdateScoresForStates(best_generator, &next_correct_state, 1, -1);
        return -1;
      }
#endif // end for EARLY_UPDATE
    }
  }

  // if (is_train) {
  //   CStateItem * best_generator = (*lattice_index[round-1]);
  //   for (CStateItem ** q = lattice_index[round-1]; q != lattice_index[round]; ++ q) {
  //     CStateItem * p = (*q);
  //     if (best_generator->score < p->score) {
  //       best_generator = p;
  //     }
  //   }
  //   if (best_generator != correct_state) {
  //     UpdateScoresForStates(best_generator, correct_state, 1, -1);
  //   }
  //   return -1;
  // }
  //delete[] sequence_correct_state;
  /*
  if (is_train) {
    //correct_state->StandardFinish(); // pop the root that is left
    // then make sure that the correct item is finally the best one
    CStateItem * best_generator = (*lattice_index[round-1]);
    for (CStateItem ** q = lattice_index[round-1]; q != lattice_index[round]; ++ q) {
      CStateItem * p = (*q);
      if (best_generator->score < p->score) {
        best_generator = p;
      }
    }
    {
      //TRACE("The best item is not the correct one")
      UpdateScoresForStates(best_generator, correct_state, 1, -1) ;
    }
  }
  */

  if (!retval0 || !retval1) {
    return -1;
  }

  TRACE("Output sentence");
  std::sort(lattice_index[round - 1], lattice_index[round], StateHeapMore);
  num_results = lattice_index[round] - lattice_index[round - 1];

  for (int i = 0; i < std::min(num_results, nbest); ++ i) {
    assert((*(lattice_index[round - 1] + i))->size() == m_lCache.size());
    (*(lattice_index[round - 1] + i))->GenerateTree(sentence, retval0[i], retval1[i]);
    if (scores) {
      scores[i] = (*(lattice_index[round - 1] + i))->score;
    }
  }
  TRACE("Done, total time spent: " << double(clock() - total_start_time) / CLOCKS_PER_SEC);
  return num_results;
}
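// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the parser): InsertIntoBeam above keeps at
// most kAgendaSize candidates per round and returns how much the beam grew. A
// common way to implement this is to keep the retained items as a min-heap on
// score, so the worst item can be replaced in O(log k) when a better candidate
// arrives. The helper below is a hypothetical version over a simple Scored
// struct, not the real InsertIntoBeam.
// ---------------------------------------------------------------------------
#include <algorithm>

namespace sketch_beam_insert {

struct Scored {
   double score;
   int payload;
};

struct MinHeapCmp {
   bool operator()(const Scored &a, const Scored &b) const { return a.score > b.score; }
};

// `beam` points at an array with room for at least maxSize items; the first
// beamSize items currently form a min-heap on score (worst item at beam[0]).
int insertIntoBeam(Scored *beam, const Scored &candidate, int beamSize, int maxSize) {
   if (beamSize < maxSize) {
      beam[beamSize] = candidate;
      std::push_heap(beam, beam + beamSize + 1, MinHeapCmp());
      return 1;                                   // beam grew by one
   }
   if (candidate.score > beam[0].score) {         // better than the current worst
      std::pop_heap(beam, beam + maxSize, MinHeapCmp());
      beam[maxSize - 1] = candidate;
      std::push_heap(beam, beam + maxSize, MinHeapCmp());
   }
   return 0;                                      // size unchanged
}

} // namespace sketch_beam_insert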
/*---------------------------------------------------------------
 *
 * work - the working process shared by training and parsing
 *
 * Returns: makes a new instance of CDependencyParse
 *
 *--------------------------------------------------------------*/
int CDepParser::work(const bool is_train,
                     const CTwoStringVector & sentence,
                     CDependencyParse * retval,
                     const CDependencyParse & oracle_tree,
                     int nbest,
                     SCORE_TYPE *scores) {
#ifdef DEBUG
  clock_t total_start_time = clock();
#endif
  const int length = sentence.size();
  const int max_round = length * 2 + 1;
  const int max_lattice_size = (kAgendaSize + 1) * max_round;

  ASSERT(length < MAX_SENTENCE_SIZE, "The sentence is too long.");

  CStateItem * lattice = GetLattice(max_lattice_size);
  CStateItem * lattice_index[max_round];
  CStateItem * correct_state = lattice;

  for (int i = 0; i < max_lattice_size; ++ i) {
    lattice[i].len_ = length;
  }

  lattice[0].clear();
  correct_state = lattice;
  lattice_index[0] = lattice;
  lattice_index[1] = lattice_index[0] + 1;

  static CPackedScore packed_scores;

  TRACE("Initialising the decoding process ...");

  m_lCache.clear();
  for (int i = 0; i < length; ++ i) {
    m_lCache.push_back(CTaggedWord<CTag, TAG_SEPARATOR>(sentence[i].first, sentence[i].second));
#ifdef LABELED
    if (is_train) {
      if (i == 0) {
        m_lCacheLabel.clear();
      }
      m_lCacheLabel.push_back(CDependencyLabel(oracle_tree[i].label));
    }
#endif
  }

  int num_results = 0;
  int round = 0;
  bool is_correct;  // used for training to identify the correct state in the lattice

  // loop with the next word to process in the sentence; 'round' indexes the
  // generators, and the candidates are inserted into 'round + 1'
  for (round = 1; round < max_round; ++ round) {
    if (lattice_index[round - 1] == lattice_index[round]) {
      // There is nothing in the generators: pruning has cut every legal
      // generator. In this kind of case we should really raise an exception;
      // however, to still produce a parse tree, an alternative solution is to
      // go back to the previous round.
      WARNING("Parsing Failed!");
      -- round;
      break;
    }

    current_beam_size_ = 0;
    // loop over the generator states
    // std::cout << "round : " << round << std::endl;
    for (CStateItem * q = lattice_index[round - 1]; q != lattice_index[round]; ++ q) {
      const CStateItem * generator = q;
      packed_scores.reset();
      GetOrUpdateStackScore(generator, packed_scores, action::kNoAction);
      Transit(generator, packed_scores);
    }

    for (unsigned i = 0; i < current_beam_size_; ++ i) {
      const CScoredTransition& transition = m_kBestTransitions[i];
      CStateItem* target = lattice_index[round] + i;
      (*target) = (*transition.source);
      // generate the candidate state according to the transitions in the beam
      target->Move(transition.action);
      target->score = transition.score;
      target->previous_ = transition.source;
    }
    lattice_index[round + 1] = lattice_index[round] + current_beam_size_;

    if (is_train) {
      CStateItem next_correct_state(*correct_state);
      next_correct_state.StandardMoveStep(oracle_tree
#ifdef LABELED
          , m_lCacheLabel
#endif // end for LABELED
          );
      next_correct_state.previous_ = correct_state;
      is_correct = false;

      for (CStateItem *p = lattice_index[round]; p != lattice_index[round + 1]; ++ p) {
        if (next_correct_state.last_action == p->last_action
            && p->previous_ == correct_state) {
          correct_state = p;
          is_correct = true;
          break;
        }
      }

#ifdef EARLY_UPDATE
      if (!is_correct) {
        TRACE("ERROR at the " << next_correct_state.size() << "th word;"
              << " Total is " << oracle_tree.size());

        CStateItem * best_generator = lattice_index[round];
        for (CStateItem * p = lattice_index[round]; p != lattice_index[round + 1]; ++ p) {
          if (best_generator->score < p->score) {
            best_generator = p;
          }
        }
        UpdateScoresForStates(best_generator, &next_correct_state, 1, -1);
        return -1;
      }
#endif // end for EARLY_UPDATE
    }
  }

  if (is_train) {
    CStateItem * best_generator = lattice_index[round - 1];
    for (CStateItem * p = lattice_index[round - 1]; p != lattice_index[round]; ++ p) {
      if (best_generator->score < p->score) {
        best_generator = p;
      }
    }
    if (best_generator != correct_state) {
      UpdateScoresForStates(best_generator, correct_state, 1, -1);
    }
    return -1;
  }

  if (!retval) {
    return -1;
  }

  TRACE("Output sentence");
  std::sort(lattice_index[round - 1], lattice_index[round], StateMore);
  num_results = lattice_index[round] - lattice_index[round - 1];

  for (int i = 0; i < std::min(num_results, nbest); ++ i) {
    assert((lattice_index[round - 1] + i)->size() == m_lCache.size());
    (lattice_index[round - 1] + i)->GenerateTree(sentence, retval[i]);
    if (scores) {
      scores[i] = (lattice_index[round - 1] + i)->score;
    }
  }
  TRACE("Done, total time spent: " << double(clock() - total_start_time) / CLOCKS_PER_SEC);
  return num_results;
}
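// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the parser): unlike the pointer-wrapper
// version above, this variant appears to first score every (source state,
// action) pair as a transition, keep only the best kAgendaSize transitions,
// and then materialize the target states into the next lattice segment. The
// toy below shows that two-phase pattern; Transition, applyAction and
// expandRound are hypothetical stand-ins, and the toy scores replace the real
// feature-based ones.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cstddef>
#include <vector>

namespace sketch_kbest_transitions {

struct State { int value = 0; double score = 0.0; };
struct Transition { const State *source; int action; double score; };

// hypothetical action application: the new state records which action made it
State applyAction(const State &src, int action, double score) {
   State out;
   out.value = src.value * 10 + action;
   out.score = score;
   return out;
}

std::vector<State> expandRound(const std::vector<State> &generators,
                               int numActions, std::size_t beamSize) {
   // phase 1: score all transitions from all generators
   std::vector<Transition> transitions;
   for (const State &g : generators)
      for (int a = 0; a < numActions; ++a)
         transitions.push_back(Transition{ &g, a, g.score + 0.1 * a });   // toy scores
   // keep only the beamSize best transitions
   std::size_t keep = std::min(beamSize, transitions.size());
   std::partial_sort(transitions.begin(), transitions.begin() + keep, transitions.end(),
                     [](const Transition &x, const Transition &y) { return x.score > y.score; });
   transitions.resize(keep);
   // phase 2: materialize the surviving target states
   std::vector<State> next;
   for (const Transition &t : transitions)
      next.push_back(applyAction(*t.source, t.action, t.score));
   return next;
}

} // namespace sketch_kbest_transitions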