void M2MFstAligner::write_lattice(string lattice) {
  //Write out the entire training set in lattice format.
  //Perform the union first. This output can then
  // be plugged directly in to a counter to obtain expected
  // alignment counts for the EM-trained corpus. Yields
  // far higher-quality joint n-gram models, which are also
  // more robust for smaller training corpora.
  //Make sure you call this BEFORE any call to
  // write_all_alignments, as the latter function
  // will override some of the weights.

  //Chaining the standard Union operation, including using a
  // rational FST, still performs very poorly in the log semiring.
  //Presumably it's running push or something at each step. It
  // should be fine to do that just once at the end.
  //Rolling our own union turns out to be MUCH faster.
  VectorFst<LogArc> ufst;
  ufst.AddState();
  ufst.SetStart(0);
  int total_states = 0;
  for (int i = 0; i < fsas.size(); i++) {
    TopSort(&fsas[i]);
    for (StateIterator<VectorFst<LogArc> > siter(fsas[i]);
         !siter.Done(); siter.Next()) {
      LogArc::StateId q = siter.Value();
      LogArc::StateId r;
      if (q == 0)
        r = 0;
      else
        r = ufst.AddState();
      for (ArcIterator<VectorFst<LogArc> > aiter(fsas[i], q);
           !aiter.Done(); aiter.Next()) {
        const LogArc& arc = aiter.Value();
        ufst.AddArc(r, LogArc(arc.ilabel, arc.ilabel, arc.weight,
                              arc.nextstate + total_states));
      }
      if (fsas[i].Final(q) != LogWeight::Zero())
        ufst.SetFinal(r, LogWeight::One());
    }
    total_states += fsas[i].NumStates() - 1;
  }
  //Normalize weights
  Push(&ufst, REWEIGHT_TO_INITIAL);
  //Write the resulting lattice to disk
  ufst.Write(lattice);
  //Write the syms table too.
  isyms->WriteText("lattice.syms");
  return;
}
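//Hedged usage sketch (not in the original sources).  It assumes
// write_lattice("lattice.fst") was called above; "lattice.syms" is the
// fixed symbol-table name the function writes.  The union lattice is an
// ordinary LogArc VectorFst, so a downstream counter can simply read it
// back from disk along with its symbol table:
void example_read_lattice() {
  VectorFst<LogArc>* lat = VectorFst<LogArc>::Read("lattice.fst");
  SymbolTable* syms = SymbolTable::ReadText("lattice.syms");
  if (lat != NULL && syms != NULL)
    cout << "Lattice states: " << lat->NumStates()
         << "  symbols: " << syms->NumSymbols() << endl;
  delete lat;
  delete syms;
}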
vector<PathData> M2MFstAligner::entry2alignfstnoinit(vector<string> seq1,
                                                     vector<string> seq2,
                                                     int nbest,
                                                     string lattice) {
  VectorFst<LogArc> fst;
  Sequences2FSTNoInit(&fst, &seq1, &seq2);
  if (lattice.compare("") != 0)
    fst.Write(lattice);
  return write_alignment(fst, nbest);
}
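//Hedged usage sketch (not in the original sources; 'aligner' is assumed
// to be an already-trained M2MFstAligner, and the token vectors are
// built by the caller).  It shows the intended call pattern for the
// wrapper above: pass "" as the lattice name to skip the per-entry
// lattice dump.
void example_align_entry(M2MFstAligner& aligner,
                         vector<string> seq1, vector<string> seq2) {
  //nbest = 1 requests only the single best alignment
  vector<PathData> paths = aligner.entry2alignfstnoinit(seq1, seq2, 1, "");
  cout << "alignments returned: " << paths.size() << endl;
}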
void M2MFstAligner::write_model(string _model_file) {
  VectorFst<LogArc> model;
  model.AddState();
  model.SetStart(0);
  model.SetFinal(0, LogWeight::One());
  map<LogArc::Label,LogWeight>::iterator it;
  for (it = alignment_model.begin(); it != alignment_model.end(); it++)
    model.AddArc(0, LogArc((*it).first, (*it).first, (*it).second, 0));
  model.SetInputSymbols(isyms);
  model.Write(_model_file);
  return;
}
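//Hedged inverse sketch (hypothetical helper, not in the original
// sources).  write_model() above serializes the alignment model as a
// single-state LogArc acceptor with one arc per joint label, so the
// model map can be recovered with a plain arc walk:
void example_read_model(string model_file,
                        map<LogArc::Label,LogWeight>* model) {
  VectorFst<LogArc>* mfst = VectorFst<LogArc>::Read(model_file);
  if (mfst == NULL)
    return;
  for (ArcIterator<VectorFst<LogArc> > aiter(*mfst, mfst->Start());
       !aiter.Done(); aiter.Next()) {
    const LogArc& arc = aiter.Value();
    (*model)[arc.ilabel] = arc.weight;  //ilabel == olabel by construction
  }
  delete mfst;
}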
//Arc is a template parameter; the header appears to have been lost in
// formatting and is restored here.
template <class Arc>
void process(const FstClass& _fst, const char *output,
             const string& separator, const string* space) {
  const Fst<Arc>& fst = *_fst.GetFst<Arc>();
  Verify(fst);
  fst::SymbolTable *const_syms = new fst::SymbolTable("const syms");
  const_syms->AddSymbol("<s>");
  const_syms->AddSymbol("</s>");
  const_syms->AddSymbol("<space>");
  const_syms->AddSymbol("<phrase>");
  const_syms->AddSymbol("<epsilon>");
  const_syms->AddSymbol("!NULL");
  VectorFst<Arc> ofst;
  SplitSymbols<Arc>(fst, &ofst, separator, space, const_syms);
  delete const_syms;
  FstWriteOptions opts(output);
  ofilter os(output);
  ofst.Write(os, opts);
}
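//Hedged usage sketch (not in the original sources; the file names, the
// separator, and the explicit StdArc instantiation are assumptions for
// illustration).  The script-level FstClass wrapper hides the arc type
// stored on disk; the template above then operates on the typed FST it
// contains:
void example_process() {
  fst::script::FstClass* ifst = fst::script::FstClass::Read("in.fst");
  if (ifst == NULL)
    return;
  string space = "<space>";
  process<fst::StdArc>(*ifst, "out.fst", "|", &space);
  delete ifst;
}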
void M2MFstAligner::_conditional_max( bool y_given_x ){
  /*
     Compute the conditional distribution, P(Y|X), using the WFST
     paradigm.  This is based on the approach from Shu and
     Hetherington 2002.  It is assumed that all WFSTs and operations
     use the Log semiring.

     Given:
        FST1 = P(X,Y)
     Compute:
        FST2 = P(X)   := Map_inv(Det(RmEps(Proj_i(FST1))))
        FST3 = P(Y|X) := Compose(FST2,FST1)

     Proj_i:  project on input labels
     RmEps:   epsilon removal
     Det:     determinize
     Map_inv: invert weights

     Notes: An analogous process may be used to compute P(X|Y).  In
      this case one would project on OUTPUT labels - Proj_o, and
      reverse the composition order to Compose(FST1,FST2).

     Future work: What we are doing here in terms of *generating* the
      JOINT fst each time is really just a dumb hack.  We *should*
      encode the model in an FST and encode the individual lattices,
      rather than doing the hacky manual label encoding that we
      currently rely on.
  */

  //Joint distribution that we start with
  VectorFst<LogArc>* joint = new VectorFst<LogArc>();
  SymbolTable* misyms = new SymbolTable("misyms");
  SymbolTable* mosyms = new SymbolTable("mosyms");
  joint->AddState();
  joint->AddState();
  joint->SetStart(0);
  joint->SetFinal(1, LogArc::Weight::One());

  map<LogArc::Label,LogWeight>::iterator it;
  for( it=prev_alignment_model.begin(); it != prev_alignment_model.end(); it++ ){
    string isym = isyms->Find((*it).first);
    vector<string> io = tokenize_utf8_string( &isym, &s1s2_sep );
    LogArc arc( misyms->AddSymbol(io[0]), mosyms->AddSymbol(io[1]),
                (*it).second, 1 );
    joint->AddArc( 0, arc );
  }
  //VectorFst<LogArc>* joint = new VectorFst<LogArc>();
  //Push<LogArc,REWEIGHT_TO_FINAL>(*_joint, joint, kPushWeights);
  //joint->SetFinal(1,LogWeight::One());
  joint->Write("m2mjoint.fst");

  //BEGIN COMPUTE MARGINAL P(X)
  VectorFst<LogArc>* dmarg;
  if( y_given_x )
    dmarg = new VectorFst<LogArc>(ProjectFst<LogArc>(*joint, PROJECT_INPUT));
  else
    dmarg = new VectorFst<LogArc>(ProjectFst<LogArc>(*joint, PROJECT_OUTPUT));
  RmEpsilon(dmarg);
  VectorFst<LogArc>* marg = new VectorFst<LogArc>();
  Determinize(*dmarg, marg);
  ArcMap(marg, InvertWeightMapper<LogArc>());
  if( y_given_x )
    ArcSort(marg, OLabelCompare<LogArc>());
  else
    ArcSort(marg, ILabelCompare<LogArc>());
  //END COMPUTE MARGINAL P(X)
  marg->Write("marg.fst");

  //CONDITIONAL P(Y|X)
  VectorFst<LogArc>* cond = new VectorFst<LogArc>();
  if( y_given_x )
    Compose(*marg, *joint, cond);
  else
    Compose(*joint, *marg, cond);
  //cond now contains the conditional distribution P(Y|X)
  cond->Write("cond.fst");

  //Now update the model with the new values
  for( MutableArcIterator<VectorFst<LogArc> > aiter(cond, 0);
       !aiter.Done(); aiter.Next() ){
    LogArc arc = aiter.Value();
    string lab = misyms->Find(arc.ilabel) + "}" + mosyms->Find(arc.olabel);
    int labi = isyms->Find(lab);
    alignment_model[labi]      = arc.weight;
    prev_alignment_model[labi] = LogWeight::Zero();
  }

  //NB: 'delete a, b, c;' uses the comma operator and frees only the
  // first pointer, so each object is deleted explicitly here.
  delete joint;
  delete marg;
  delete cond;
  delete dmarg;
  delete misyms;
  delete mosyms;
  return;
}
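//Hedged standalone sketch (not in the original sources; the toy labels
// and weights are invented for illustration).  It reproduces the Shu &
// Hetherington recipe documented above on a two-arc joint distribution
// P(X,Y), entirely in the Log semiring: project to get P(X), remove
// epsilons, determinize (parallel arcs combine by log-plus), invert the
// weights, and compose back to obtain P(Y|X).
void example_conditionalize() {
  VectorFst<LogArc> joint;
  joint.AddState();
  joint.AddState();
  joint.SetStart(0);
  joint.SetFinal(1, LogWeight::One());
  //Two joint events sharing input label 1: -log(0.2) and -log(0.3)
  joint.AddArc(0, LogArc(1, 1, 1.6094, 1));
  joint.AddArc(0, LogArc(1, 2, 1.2040, 1));

  //P(X) = Map_inv(Det(RmEps(Proj_i(P(X,Y)))))
  VectorFst<LogArc> dmarg(ProjectFst<LogArc>(joint, PROJECT_INPUT));
  RmEpsilon(&dmarg);
  VectorFst<LogArc> marg;
  Determinize(dmarg, &marg);              //single arc: -log(0.2+0.3)
  ArcMap(&marg, InvertWeightMapper<LogArc>());
  ArcSort(&marg, OLabelCompare<LogArc>());

  //P(Y|X) = Compose(P(X), P(X,Y))
  VectorFst<LogArc> cond;
  Compose(marg, joint, &cond);
  //cond's two arcs now carry -log(0.2/0.5) and -log(0.3/0.5),
  // i.e. P(Y|X=1) = {0.4, 0.6}
  cond.Write("example_cond.fst");
}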
void train_model(string eps, string s1s2_sep, string skip, int order,
                 string smooth, string prefix, string seq_sep, string prune,
                 double theta, string count_pattern) {
  namespace s = fst::script;
  using fst::script::FstClass;
  using fst::script::MutableFstClass;
  using fst::script::VectorFstClass;
  using fst::script::WeightClass;

  // create symbols file
  cout << "Generating symbols..." << endl;
  NGramInput *ingram = new NGramInput(prefix + ".corpus.aligned",
                                      prefix + ".corpus.syms",
                                      "", eps, unknown_symbol, "", "");
  ingram->ReadInput(0, 1);

  // compile strings into a far archive
  cout << "Compiling symbols into FAR archive..." << endl;
  fst::FarEntryType fet = fst::StringToFarEntryType(entry_type);
  fst::FarTokenType ftt = fst::StringToFarTokenType(token_type);
  fst::FarType fartype = fst::FarTypeFromString(far_type);
  delete ingram;

  vector<string> in_fname;
  in_fname.push_back(prefix + ".corpus.aligned");
  fst::script::FarCompileStrings(in_fname, prefix + ".corpus.far",
                                 arc_type, fst_type, fartype,
                                 generate_keys, fet, ftt,
                                 prefix + ".corpus.syms", unknown_symbol,
                                 keep_symbols, initial_symbols,
                                 allow_negative_labels, file_list_input,
                                 key_prefix, key_suffix);

  //count n-grams
  cout << "Counting n-grams..." << endl;
  NGramCounter<Log64Weight> ngram_counter(order, epsilon_as_backoff);
  FstReadOptions opts;
  FarReader<StdArc> *far_reader;
  far_reader = FarReader<StdArc>::Open(prefix + ".corpus.far");
  int fstnumber = 1;
  const Fst<StdArc> *ifst = 0, *lfst = 0;
  while (!far_reader->Done()) {
    if (ifst)
      delete ifst;
    ifst = far_reader->GetFst().Copy();
    if (!ifst) {
      E_FATAL("ngramcount: unable to read fst #%d\n", fstnumber);
      //exit(1);
    }
    bool counted = false;
    if (ifst->Properties(kString | kUnweighted, true)) {
      counted = ngram_counter.Count(*ifst);
    } else {
      VectorFst<Log64Arc> log_ifst;
      Map(*ifst, &log_ifst, ToLog64Mapper<StdArc>());
      counted = ngram_counter.Count(&log_ifst);
    }
    if (!counted)
      cout << "ngramcount: fst #" << fstnumber << " not counted" << endl;

    if (ifst->InputSymbols() != 0) {
      // retain for symbol table
      if (lfst)
        delete lfst;            // delete previously observed symbol table
      lfst = ifst;
      ifst = 0;
    }
    far_reader->Next();
    ++fstnumber;
  }
  delete far_reader;

  if (!lfst) {
    E_FATAL("None of the input FSTs had a symbol table\n");
    //exit(1);
  }

  VectorFst<StdArc> vfst;
  ngram_counter.GetFst(&vfst);
  ArcSort(&vfst, StdILabelCompare());
  vfst.SetInputSymbols(lfst->InputSymbols());
  vfst.SetOutputSymbols(lfst->InputSymbols());
  vfst.Write(prefix + ".corpus.cnts");

  StdMutableFst *fst = StdMutableFst::Read(prefix + ".corpus.cnts", true);
  if (smooth != "no") {
    cout << "Smoothing model..." << endl;
    bool prefix_norm = 0;
    if (smooth == "presmoothed") {   // only for use with randgen counts
      prefix_norm = 1;
      smooth = "unsmoothed";         // normalizes only based on prefix count
    }
    if (smooth == "kneser_ney") {
      NGramKneserNey ngram(fst, backoff, backoff_label, norm_eps,
                           check_consistency, discount_D, bins);
      ngram.MakeNGramModel();
      fst = ngram.GetMutableFst();
    } else if (smooth == "absolute") {
      NGramAbsolute ngram(fst, backoff, backoff_label, norm_eps,
                          check_consistency, discount_D, bins);
      ngram.MakeNGramModel();
      fst = ngram.GetMutableFst();
    } else if (smooth == "katz") {
      NGramKatz ngram(fst, backoff, backoff_label, norm_eps,
                      check_consistency, bins);
      ngram.MakeNGramModel();
      fst = ngram.GetMutableFst();
    } else if (smooth == "witten_bell") {
      NGramWittenBell ngram(fst, backoff, backoff_label, norm_eps,
                            check_consistency, witten_bell_k);
      ngram.MakeNGramModel();
      fst = ngram.GetMutableFst();
    } else if (smooth == "unsmoothed") {
      NGramUnsmoothed ngram(fst, 1, prefix_norm, backoff_label, norm_eps,
                            check_consistency);
      ngram.MakeNGramModel();
      fst = ngram.GetMutableFst();
    } else {
      E_FATAL("Bad smoothing method: %s\n", smooth.c_str());
    }
  }

  if (prune != "no") {
    cout << "Pruning model..." << endl;
    if (prune == "count_prune") {
      NGramCountPrune ngramsh(fst, count_pattern, shrink_opt,
                              total_unigram_count, backoff_label,
                              norm_eps, check_consistency);
      ngramsh.ShrinkNGramModel();
    } else if (prune == "relative_entropy") {
      NGramRelEntropy ngramsh(fst, theta, shrink_opt, total_unigram_count,
                              backoff_label, norm_eps, check_consistency);
      ngramsh.ShrinkNGramModel();
    } else if (prune == "seymore") {
      NGramSeymoreShrink ngramsh(fst, theta, shrink_opt,
                                 total_unigram_count, backoff_label,
                                 norm_eps, check_consistency);
      ngramsh.ShrinkNGramModel();
    } else {
      E_FATAL("Bad shrink method: %s\n", prune.c_str());
    }
  }

  cout << "Minimizing model..." << endl;
  MutableFstClass *minimized = new s::MutableFstClass(*fst);
  Minimize(minimized, 0, fst::kDelta);
  fst = minimized->GetMutableFst<StdArc>();

  cout << "Correcting final model..." << endl;
  StdMutableFst *out = new StdVectorFst();
  relabel(fst, out, prefix, eps, skip, s1s2_sep, seq_sep);

  cout << "Writing binary model to disk..." << endl;
  out->Write(prefix + ".fst");
}
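//Hedged usage sketch (not in the original sources; the literal argument
// values and the corpus prefix are assumptions, and train_model() also
// reads globals such as backoff, bins, and shrink_opt that must be
// configured elsewhere).  Training a 7-gram Kneser-Ney model over a
// pre-aligned corpus named "model.corpus.aligned" would look like:
void example_train() {
  //eps, s1s2_sep, skip, order, smooth, prefix, seq_sep, prune,
  // theta, count_pattern
  train_model("<eps>", "}", "_", 7, "kneser_ney", "model",
              "|", "no", 0.0, "");
  //Intermediate artifacts land in model.corpus.{syms,far,cnts};
  // the final binary model is written to model.fst.
}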