Example #1
void
M2MFstAligner::write_lattice(string lattice)
{
    //Write out the entire training set in lattice format.
    //Perform the union first.  This output can then
    // be plugged directly into a counter to obtain expected
    // alignment counts for the EM-trained corpus.  This yields
    // far higher-quality joint n-gram models, which are also
    // more robust for smaller training corpora.
    //Make sure you call this BEFORE any call to
    // write_all_alignments, as the latter function
    // will override some of the weights.

    //Chaining the standard Union operation, even via a
    // rational FST, performs very poorly in the log semiring,
    // presumably because it runs weight pushing or something
    // similar at each step; it should be fine to do that just
    // once at the end.
    //Rolling our own union turns out to be MUCH faster.
    VectorFst<LogArc> ufst;
    ufst.AddState();
    ufst.SetStart(0);
    int total_states = 0;
    for (size_t i = 0; i < fsas.size(); i++) {
        TopSort(&fsas[i]);
        for (StateIterator<VectorFst<LogArc> > siter(fsas[i]);
                !siter.Done(); siter.Next()) {
            LogArc::StateId q = siter.Value();
            LogArc::StateId r;
            if (q == 0)
                r = 0;
            else
                r = ufst.AddState();

            for (ArcIterator <VectorFst<LogArc> > aiter(fsas[i], q);
                    !aiter.Done(); aiter.Next()) {
                const LogArc & arc = aiter.Value();
                //The training FSAs are acceptors, so the input
                // label is reused as the output label.
                ufst.AddArc(r,
                            LogArc(arc.ilabel, arc.ilabel, arc.weight,
                                   arc.nextstate + total_states));
            }
            //Keep finality but reset the final weight; the Push
            // below renormalizes the lattice.
            if (fsas[i].Final(q) != LogWeight::Zero())
                ufst.SetFinal(r, LogWeight::One());
        }
        total_states += fsas[i].NumStates() - 1;
    }
    //Normalize weights
    Push(&ufst, REWEIGHT_TO_INITIAL);
    //Write the resulting lattice to disk
    ufst.Write(lattice);
    //Write the syms table too.
    isyms->WriteText("lattice.syms");
    return;
}
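For reference, a minimal companion sketch (not part of the original source) showing how the union lattice written above can be read back with the standard OpenFst API; the path "lattice.fst" is an assumed example of whatever was passed as the lattice argument:

#include <iostream>
#include <fst/fstlib.h>
using namespace fst;

int main() {
    //Read the log-semiring union lattice produced by write_lattice().
    VectorFst<LogArc> *lat = VectorFst<LogArc>::Read("lattice.fst");
    if (lat == NULL)
        return 1;
    std::cout << "Lattice states: " << lat->NumStates() << std::endl;
    delete lat;
    return 0;
}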
Example #2
vector<PathData>
M2MFstAligner::entry2alignfstnoinit(vector<string> seq1, vector<string> seq2,
                                    int nbest, string lattice)
{
    VectorFst<LogArc> fst;
    Sequences2FSTNoInit(&fst, &seq1, &seq2);
    if (!lattice.empty())
        fst.Write(lattice);
    return write_alignment(fst, nbest);
}
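A hedged usage sketch for this entry point (the sequences and nbest value are illustrative, and aligner is assumed to be an already-initialized M2MFstAligner; its construction is project-specific):

//Hypothetical usage; `aligner` is an initialized M2MFstAligner.
vector<string> seq1, seq2;
seq1.push_back("r"); seq1.push_back("i"); seq1.push_back("g");
seq1.push_back("h"); seq1.push_back("t");
seq2.push_back("R"); seq2.push_back("AY"); seq2.push_back("T");
//An empty lattice path skips writing the per-entry lattice to disk.
vector<PathData> paths = aligner.entry2alignfstnoinit(seq1, seq2, 1, "");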
Example #3
void
M2MFstAligner::write_model(string _model_file)
{
    VectorFst<LogArc> model;
    model.AddState();
    model.SetStart(0);
    model.SetFinal(0, LogWeight::One());
    map<LogArc::Label,LogWeight>::iterator it;
    for (it = alignment_model.begin(); it != alignment_model.end(); it++)
        model.AddArc(0, LogArc((*it).first, (*it).first, (*it).second, 0));
    model.SetInputSymbols(isyms);
    model.Write(_model_file);
    return;
}
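A minimal companion sketch (assumed file name, not from the original source) that loads the single-state model written by write_model() and prints each alignment label with its log weight:

#include <iostream>
#include <fst/fstlib.h>
using namespace fst;

int main() {
    VectorFst<LogArc> *model = VectorFst<LogArc>::Read("model.fst");
    if (model == NULL)
        return 1;
    const SymbolTable *syms = model->InputSymbols();
    //Every arc is a self-loop on state 0, one per alignment-model entry.
    for (ArcIterator<VectorFst<LogArc> > aiter(*model, 0);
            !aiter.Done(); aiter.Next()) {
        const LogArc &arc = aiter.Value();
        std::cout << syms->Find(arc.ilabel) << "\t"
                  << arc.weight.Value() << std::endl;
    }
    delete model;
    return 0;
}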
Example #4
void process(const FstClass& _fst, const char *output, const string& separator, const string* space) {
  const Fst<Arc>& fst = *_fst.GetFst<Arc>();
  Verify(fst);

  fst::SymbolTable * const_syms = new fst::SymbolTable("const syms");
  const_syms->AddSymbol("<s>");
  const_syms->AddSymbol("</s>");
  const_syms->AddSymbol("<space>");
  const_syms->AddSymbol("<phrase>");
  const_syms->AddSymbol("<epsilon>");
  const_syms->AddSymbol("!NULL");

  VectorFst<Arc> ofst;
  SplitSymbols<Arc>(fst, &ofst, separator, space, const_syms);

  delete const_syms;

  FstWriteOptions opts(output);
  ofilter os(output);
  ofst.Write(os, opts);
}
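Since ofilter is project-specific, a hedged alternative (an assumption, not the original code) is to write through a plain std::ofstream, which Fst::Write also accepts:

#include <fstream>
#include <fst/fstlib.h>

//Sketch: write an FST to a path via a standard output stream.
template <class Arc>
void WriteFstToPath(const fst::VectorFst<Arc> &ofst, const char *output) {
    std::ofstream strm(output, std::ios_base::out | std::ios_base::binary);
    fst::FstWriteOptions opts(output);
    ofst.Write(strm, opts);
}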
Example #5
void M2MFstAligner::_conditional_max( bool y_given_x ){
  /*
    Compute the conditional distribution, P(Y|X) using the WFST paradigm.
    This is based on the approach from Shu and Hetherington 2002.
    It is assumed that all WFSTs and operations use the Log semiring.

    Given: 
           FST1 = P(X,Y)
    Compute:
           FST2 = P(X) 
             := Map_inv(Det(RmEps(Proj_i(FST1))))
           FST3 = P(Y|X)
             := Compose(FST2,FST1)

    Proj_i:  project on input labels
    RmEps:   epsilon removal
    Det:     determinize
    Map_inv: invert weights

    Notes: An analogous process may be used to compute P(X|Y).  In this
      case one would project on OUTPUT labels - Proj_o, and reverse the
      composition order to Compose(FST1,FST2).

    Future work:
      What we are doing here in terms of *generating* the JOINT fst each
      time is really just a dumb hack.  We *should* encode the model in an
      FST and encode the individual lattices, rather than doing the hacky
      manual label encoding that we currently rely on.
  */

  //Joint distribution that we start with
  VectorFst<LogArc>* joint  = new VectorFst<LogArc>();
  SymbolTable* misyms = new SymbolTable("misyms");
  SymbolTable* mosyms = new SymbolTable("mosyms");
  joint->AddState();
  joint->AddState();
  joint->SetStart(0);
  joint->SetFinal(1,LogArc::Weight::One());
  map<LogArc::Label,LogWeight>::iterator it;
  for( it=prev_alignment_model.begin(); it != prev_alignment_model.end(); it++ ){
    string isym = isyms->Find((*it).first); 
    vector<string> io = tokenize_utf8_string( &isym, &s1s2_sep );
    LogArc arc( misyms->AddSymbol(io[0]), mosyms->AddSymbol(io[1]), (*it).second, 1 );
    joint->AddArc( 0, arc );
  }
  //VectorFst<LogArc>* joint  = new VectorFst<LogArc>();
  //Push<LogArc,REWEIGHT_TO_FINAL>(*_joint, joint, kPushWeights);
  //joint->SetFinal(1,LogWeight::One());
  joint->Write("m2mjoint.fst");
  //BEGIN COMPUTE MARGINAL P(X)  
  VectorFst<LogArc>* dmarg;
  if( y_given_x )
    dmarg = new VectorFst<LogArc>(ProjectFst<LogArc>(*joint, PROJECT_INPUT));
  else
    dmarg = new VectorFst<LogArc>(ProjectFst<LogArc>(*joint, PROJECT_OUTPUT));

  RmEpsilon(dmarg);
  VectorFst<LogArc>* marg = new VectorFst<LogArc>();
  Determinize(*dmarg, marg);
  ArcMap(marg, InvertWeightMapper<LogArc>());

  if( y_given_x )
    ArcSort(marg, OLabelCompare<LogArc>());
  else
    ArcSort(marg, ILabelCompare<LogArc>());
  //END COMPUTE MARGINAL P(X)
  marg->Write("marg.fst");

  //CONDITIONAL P(Y|X)
  VectorFst<LogArc>* cond = new VectorFst<LogArc>();
  if( y_given_x )
    Compose(*marg, *joint, cond);
  else
    Compose(*joint, *marg, cond);
  //cond now contains the conditional distribution P(Y|X)
  cond->Write("cond.fst");
  //Now update the model with the new values
  for( ArcIterator<VectorFst<LogArc> > aiter(*cond, 0); !aiter.Done(); aiter.Next() ){
    const LogArc& arc = aiter.Value();
    //Rebuild the joint label; the hard-coded "}" must match s1s2_sep.
    string lab = misyms->Find(arc.ilabel)+"}"+mosyms->Find(arc.olabel);
    LogArc::Label labi = isyms->Find(lab);
    alignment_model[labi]      = arc.weight;
    prev_alignment_model[labi] = LogWeight::Zero();
  }
  //Note: "delete a, b;" only deletes the first pointer (comma
  // operator), so each object must be deleted explicitly.
  delete joint;
  delete dmarg;
  delete marg;
  delete cond;
  delete misyms;
  delete mosyms;
  return;
}
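For the record, the weight arithmetic that makes the Map_inv/Compose trick work (implicit in the comments above): log-semiring weights are negative log probabilities, composition multiplies probabilities by adding weights, and inverting the determinized marginal's weights negates them, so

    -log P(X,Y) + log P(X) = -log( P(X,Y) / P(X) ) = -log P(Y|X)

which is exactly the conditional weight that ends up on each arc of cond.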
Example #6
void
train_model(string eps, string s1s2_sep, string skip, int order,
            string smooth, string prefix, string seq_sep, string prune,
            double theta, string count_pattern)
{
    namespace s = fst::script;
    using fst::script::FstClass;
    using fst::script::MutableFstClass;
    using fst::script::VectorFstClass;
    using fst::script::WeightClass;

    // create symbols file
    cout << "Generating symbols..." << endl;
    NGramInput *ingram =
        new NGramInput(prefix + ".corpus.aligned", prefix + ".corpus.syms",
                       "", eps, unknown_symbol, "", "");
    ingram->ReadInput(0, 1);

    // compile strings into a far archive
    cout << "Compiling symbols into FAR archive..." << endl;
    fst::FarEntryType fet = fst::StringToFarEntryType(entry_type);
    fst::FarTokenType ftt = fst::StringToFarTokenType(token_type);
    fst::FarType fartype = fst::FarTypeFromString(far_type);

    delete ingram;

    vector<string> in_fname;
    in_fname.push_back(prefix + ".corpus.aligned");

    fst::script::FarCompileStrings(in_fname, prefix + ".corpus.far",
                                   arc_type, fst_type, fartype,
                                   generate_keys, fet, ftt,
                                   prefix + ".corpus.syms", unknown_symbol,
                                   keep_symbols, initial_symbols,
                                   allow_negative_labels, file_list_input,
                                   key_prefix, key_suffix);

    //count n-grams
    cout << "Counting n-grams..." << endl;
    NGramCounter<Log64Weight> ngram_counter(order, epsilon_as_backoff);

    FstReadOptions opts;
    FarReader<StdArc> *far_reader;
    far_reader = FarReader<StdArc>::Open(prefix + ".corpus.far");
    int fstnumber = 1;
    const Fst<StdArc> *ifst = 0, *lfst = 0;
    while (!far_reader->Done()) {
        if (ifst)
            delete ifst;
        ifst = far_reader->GetFst().Copy();

        if (!ifst) {
            E_FATAL("ngramcount: unable to read fst #%d\n", fstnumber);
            //exit(1);
        }

        bool counted = false;
        if (ifst->Properties(kString | kUnweighted, true)) {
            counted = ngram_counter.Count(*ifst);
        }
        else {
            VectorFst<Log64Arc> log_ifst;
            Map(*ifst, &log_ifst, ToLog64Mapper<StdArc> ());
            counted = ngram_counter.Count(&log_ifst);
        }
        if (!counted)
            cout << "ngramcount: fst #" << fstnumber << endl;

        if (ifst->InputSymbols() != 0) {        // retain for symbol table
            if (lfst)
                delete lfst;    // delete previously observed symbol table
            lfst = ifst;
            ifst = 0;
        }
        far_reader->Next();
        ++fstnumber;
    }
    delete far_reader;

    if (!lfst) {
        E_FATAL("None of the input FSTs had a symbol table\n");
        //exit(1);
    }

    VectorFst<StdArc> vfst;
    ngram_counter.GetFst(&vfst);
    ArcSort(&vfst, StdILabelCompare());
    vfst.SetInputSymbols(lfst->InputSymbols());
    vfst.SetOutputSymbols(lfst->InputSymbols());
    vfst.Write(prefix + ".corpus.cnts");
    StdMutableFst *fst =
        StdMutableFst::Read(prefix + ".corpus.cnts", true);
    if (smooth != "no") {
        cout << "Smoothing model..." << endl;

        bool prefix_norm = false;
        if (smooth == "presmoothed") {  // only for use with randgen counts
            prefix_norm = true;
            smooth = "unsmoothed";      // normalizes only based on prefix count
        }
        }
        if (smooth == "kneser_ney") {
            NGramKneserNey ngram(fst, backoff, backoff_label,
                                 norm_eps, check_consistency,
                                 discount_D, bins);
            ngram.MakeNGramModel();
            fst = ngram.GetMutableFst();
        }
        else if (smooth == "absolute") {
            NGramAbsolute ngram(fst, backoff, backoff_label,
                                norm_eps, check_consistency,
                                discount_D, bins);
            ngram.MakeNGramModel();
            fst = ngram.GetMutableFst();
        }
        else if (smooth == "katz") {
            NGramKatz ngram(fst, backoff, backoff_label,
                            norm_eps, check_consistency, bins);
            ngram.MakeNGramModel();
            fst = ngram.GetMutableFst();
        }
        else if (smooth == "witten_bell") {
            NGramWittenBell ngram(fst, backoff, backoff_label,
                                  norm_eps, check_consistency,
                                  witten_bell_k);
            ngram.MakeNGramModel();
            fst = ngram.GetMutableFst();
        }
        else if (smooth == "unsmoothed") {
            NGramUnsmoothed ngram(fst, 1, prefix_norm, backoff_label,
                                  norm_eps, check_consistency);
            ngram.MakeNGramModel();
            fst = ngram.GetMutableFst();
        }
        else {
            E_FATAL("Bad smoothing method: %s\n", smooth.c_str());
        }
    }
    if (prune != "no") {
        cout << "Pruning model..." << endl;

        if (prune == "count_prune") {
            NGramCountPrune ngramsh(fst, count_pattern,
                                    shrink_opt, total_unigram_count,
                                    backoff_label, norm_eps,
                                    check_consistency);
            ngramsh.ShrinkNGramModel();
        }
        else if (prune == "relative_entropy") {
            NGramRelEntropy ngramsh(fst, theta, shrink_opt,
                                    total_unigram_count, backoff_label,
                                    norm_eps, check_consistency);
            ngramsh.ShrinkNGramModel();
        }
        else if (prune == "seymore") {
            NGramSeymoreShrink ngramsh(fst, theta, shrink_opt,
                                       total_unigram_count, backoff_label,
                                       norm_eps, check_consistency);
            ngramsh.ShrinkNGramModel();
        }
        else {
            E_FATAL("Bad shrink method:  %s\n", prune.c_str());
        }
    }

    cout << "Minimizing model..." << endl;
    MutableFstClass *minimized = new s::MutableFstClass(*fst);
    Minimize(minimized, 0, fst::kDelta);
    fst = minimized->GetMutableFst<StdArc>();

    cout << "Correcting final model..." << endl;
    StdMutableFst *out = new StdVectorFst();
    relabel(fst, out, prefix, eps, skip, s1s2_sep, seq_sep);

    cout << "Writing binary model to disk..." << endl;
    out->Write(prefix + ".fst");
}
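A hedged invocation sketch (every argument value below is illustrative, not taken from the original; the globals this function reads, such as backoff and bins, must be set up beforehand):

//Hypothetical call with illustrative arguments:
// eps, s1s2_sep, skip, order, smooth, prefix, seq_sep, prune,
// theta, count_pattern.
train_model("<eps>", "}", "_", 7, "kneser_ney", "model", "|",
            "no", 0.0, "");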