// Builds the function-pointer array and registers every built-in function
// name in the supplied symbol table.
//
// For each of the size_ entries of functionTable, the entry's function
// pointer is copied into ptrFun_ and its name is added to `tbl`. The id
// returned by AddSymbol is asserted to equal the entry's index, so a symbol
// id can be used directly as an index into ptrFun_. Each registered name is
// echoed to standard output.
void FunctionTable::Init(SymbolTable& tbl)
{
    std::cout << "function list: " << std::endl;

    ptrFun_ = new PtrFun[size_];

    unsigned int index = 0;
    while (index < size_)
    {
        ptrFun_[index] = functionTable[index].ptrfun;

        // The symbol id must line up with the slot in ptrFun_.
        unsigned int symbolId = tbl.AddSymbol(functionTable[index].name);
        assert(index == symbolId);

        std::cout << functionTable[index].name << std::endl;
        ++index;
    }

    std::cout << std::endl;
}
void M2MFstAligner::_conditional_max( bool y_given_x ){ /* Compute the conditional distribution, P(Y|X) using the WFST paradigm. This is bassed on the approach from Shu and Hetherington 2002. It is assumed that all WFSTs and operations use the Log semiring. Given: FST1 = P(X,Y) Compute: FST2 = P(X) := Map_inv(Det(RmEps(Proj_i(FST1)))) FST3 = P(Y|X) := Compose(FST2,FST1) Proj_i: project on input labels RmEps: epsilon removal Det: determinize Map_inv: invert weights Notes: An analogous process may be used to compute P(X|Y). In this case one would project on OUTPUT labels - Proj_o, and reverse the composition order to Compose(FST1,FST2). Future work: What we are doing here in terms of *generating* the JOINT fst each time is really just a dumb hack. We *should* encode the model in an FST and encode the individual lattices, rather than doing the hacky manual label encoding that we currently rely on. */ //Joint distribution that we start with VectorFst<LogArc>* joint = new VectorFst<LogArc>(); SymbolTable* misyms = new SymbolTable("misyms"); SymbolTable* mosyms = new SymbolTable("mosyms"); joint->AddState(); joint->AddState(); joint->SetStart(0); joint->SetFinal(1,LogArc::Weight::One()); map<LogArc::Label,LogWeight>::iterator it; for( it=prev_alignment_model.begin(); it != prev_alignment_model.end(); it++ ){ string isym = isyms->Find((*it).first); vector<string> io = tokenize_utf8_string( &isym, &s1s2_sep ); LogArc arc( misyms->AddSymbol(io[0]), mosyms->AddSymbol(io[1]), (*it).second, 1 ); joint->AddArc( 0, arc ); } //VectorFst<LogArc>* joint = new VectorFst<LogArc>(); //Push<LogArc,REWEIGHT_TO_FINAL>(*_joint, joint, kPushWeights); //joint->SetFinal(1,LogWeight::One()); joint->Write("m2mjoint.fst"); //BEGIN COMPUTE MARGINAL P(X) VectorFst<LogArc>* dmarg; if( y_given_x ) dmarg = new VectorFst<LogArc>(ProjectFst<LogArc>(*joint, PROJECT_INPUT)); else dmarg = new VectorFst<LogArc>(ProjectFst<LogArc>(*joint, PROJECT_OUTPUT)); RmEpsilon(dmarg); VectorFst<LogArc>* marg = new 
VectorFst<LogArc>(); Determinize(*dmarg, marg); ArcMap(marg, InvertWeightMapper<LogArc>()); if( y_given_x ) ArcSort(marg, OLabelCompare<LogArc>()); else ArcSort(marg, ILabelCompare<LogArc>()); //END COMPUTE MARGINAL P(X) marg->Write("marg.fst"); //CONDITIONAL P(Y|X) VectorFst<LogArc>* cond = new VectorFst<LogArc>(); if( y_given_x ) Compose(*marg, *joint, cond); else Compose(*joint, *marg, cond); //cond now contains the conditional distribution P(Y|X) cond->Write("cond.fst"); //Now update the model with the new values for( MutableArcIterator<VectorFst<LogArc> > aiter(cond, 0); !aiter.Done(); aiter.Next() ){ LogArc arc = aiter.Value(); string lab = misyms->Find(arc.ilabel)+"}"+mosyms->Find(arc.olabel); int labi = isyms->Find(lab); alignment_model[labi] = arc.weight; prev_alignment_model[labi] = LogWeight::Zero(); } delete joint, marg, cond, dmarg; delete misyms, mosyms; return; }
void relabel(StdMutableFst * fst, StdMutableFst * out, string prefix, string eps, string skip, string s1s2_sep, string seq_sep) { namespace s = fst::script; using fst::ostream; using fst::SymbolTable; ArcSort(fst, StdILabelCompare()); const SymbolTable *oldsyms = fst->InputSymbols(); // generate new input, output and states SymbolTables SymbolTable *ssyms = new SymbolTable("ssyms"); SymbolTable *isyms = new SymbolTable("isyms"); SymbolTable *osyms = new SymbolTable("osyms"); out->AddState(); ssyms->AddSymbol("s0"); out->SetStart(0); out->AddState(); ssyms->AddSymbol("s1"); out->SetFinal(1, TropicalWeight::One()); isyms->AddSymbol(eps); osyms->AddSymbol(eps); //Add separator, phi, start and end symbols isyms->AddSymbol(seq_sep); osyms->AddSymbol(seq_sep); isyms->AddSymbol("<phi>"); osyms->AddSymbol("<phi>"); int istart = isyms->AddSymbol("<s>"); int iend = isyms->AddSymbol("</s>"); int ostart = osyms->AddSymbol("<s>"); int oend = osyms->AddSymbol("</s>"); out->AddState(); ssyms->AddSymbol("s2"); out->AddArc(0, StdArc(istart, ostart, TropicalWeight::One(), 2)); for (StateIterator<StdFst> siter(*fst); !siter.Done(); siter.Next()) { StateId state_id = siter.Value(); int64 newstate; if (state_id == fst->Start()) { newstate = 2; } else { newstate = ssyms->Find(convertInt(state_id)); if (newstate == -1) { out->AddState(); ssyms->AddSymbol(convertInt(state_id)); newstate = ssyms->Find(convertInt(state_id)); } } TropicalWeight weight = fst->Final(state_id); if (weight != TropicalWeight::Zero()) { // this is a final state StdArc a = StdArc(iend, oend, weight, 1); out->AddArc(newstate, a); out->SetFinal(newstate, TropicalWeight::Zero()); } addarcs(state_id, newstate, oldsyms, isyms, osyms, ssyms, eps, s1s2_sep, fst, out); } out->SetInputSymbols(isyms); out->SetOutputSymbols(osyms); cout << "Writing text model to disk..." 
<< endl; //Save syms tables isyms->WriteText(prefix + ".input.syms"); osyms->WriteText(prefix + ".output.syms"); string dest = prefix + ".fst.txt"; fst::ofstream ostrm(dest.c_str()); ostrm.precision(9); s::FstClass fstc(*out); s::PrintFst(fstc, ostrm, dest, isyms, osyms, NULL, acceptor, show_weight_one); ostrm.flush(); }
int main(int argc, char* argv[]) { StdVectorFst fst; SymbolTable* isyms; SymbolTable* osyms; { isyms = new SymbolTable("isyms.txt"); osyms = new SymbolTable("osyms.txt"); isyms->AddSymbol("a"); isyms->AddSymbol("b"); isyms->AddSymbol("c"); isyms->Write("isyms.txt"); osyms->AddSymbol("x"); osyms->AddSymbol("y"); osyms->AddSymbol("z"); osyms->Write("osyms.txt"); } { fst.SetInputSymbols(isyms); fst.SetOutputSymbols(osyms); // Adds state 0 to the initially empty FST and make it the start state. fst.AddState(); // 1st state will be state 0 (returned by AddState) fst.SetStart(0); // arg is state ID // Adds two arcs exiting state 0. // Arc constructor args: ilabel, olabel, weight, dest state ID. fst.AddArc(0, StdArc(isyms->Find("a"), osyms->Find("x"), 0.5, 1)); // 1st arg is src state ID fst.AddArc(0, StdArc(isyms->Find("b"), osyms->Find("y"), 1.5, 1)); // Adds state 1 and its arc. fst.AddState(); fst.AddArc(1, StdArc(isyms->Find("c"), osyms->Find("z"), 2.5, 2)); // Adds state 2 and set its final weight. fst.AddState(); fst.SetFinal(2, 3.5); // 1st arg is state ID, 2nd arg weight fst.Write("example.fst"); } StdVectorFst search_fst; { search_fst.SetInputSymbols(isyms); search_fst.SetOutputSymbols(osyms); search_fst.AddState(); // 1st state will be state 0 (returned by AddState) search_fst.SetStart(0); // arg is state ID // Adds two arcs exiting state 0. // Arc constructor args: ilabel, olabel, weight, dest state ID. search_fst.AddArc(0, StdArc(isyms->Find("a"), osyms->Find("x"), 0.5, 1)); // 1st arg is src state ID // Adds state 1 and its arc. search_fst.AddState(); search_fst.AddArc(1, StdArc(isyms->Find("c"), osyms->Find("z"), 2.5, 2)); // Adds state 2 and set its final weight. 
search_fst.AddState(); search_fst.SetFinal(2, 3.5); // 1st arg is state ID, 2nd arg weight } { for (StateIterator<StdVectorFst> siter(fst); not siter.Done(); siter.Next()) { StdIntersectFst::StateId s = siter.Value(); std::cout << "state=" << s << ":"; for (ArcIterator<StdVectorFst> aiter(fst, s); not aiter.Done(); aiter.Next()) { const StdArc& arc = aiter.Value(); std::cout << arc.ilabel << "/" << arc.olabel << "->" << arc.nextstate << ","; } std::cout << std::endl; } } { Matcher<StdVectorFst> matcher(fst, MATCH_INPUT); matcher.SetState(0); StdArc::Label find_label = 1; if (matcher.Find(find_label)) { for (; not matcher.Done(); matcher.Next()) { const StdArc& arc = matcher.Value(); std::cout << "found=" << arc.ilabel << "/" << arc.olabel << "->" << arc.nextstate << std::endl; } } } // intersect contains strings in both A and B { ArcSort(&fst, StdOLabelCompare()); ArcSort(&search_fst, StdILabelCompare()); /* ComposeFilter compose_filter; if (!GetComposeFilter("auto", &compose_filter)) { LOG(ERROR) << "failed"; exit(1); } const fst::IntersectFstOptions<StdArc> opts; */ //StdIntersectFst ofst(fst, search_fst, opts); StdIntersectFst ofst(fst, search_fst); } }