コード例 #1
0
ファイル: mixture.cpp プロジェクト: MistSC/kaldi-trunk
// Loads the mixture interpolation weights from the parameter file `ipf`.
//
// The first line of the file must contain two integers that match lmsize()
// and pmax (this is the header written by savepar()). After the header,
// numslm doubles are read with readx() for every (level, parameter) pair,
// with level in [0, lmsize()] and parameter in [0, pmax).
//
// A file that cannot be opened, or a header that is malformed or does not
// match, is reported through exit_error(). Returns 1 otherwise.
int mixture::loadpar(char* ipf)
{

  mfstream inp(ipf,ios::in);

  if (!inp) {
    std::stringstream ss_msg;
    ss_msg << "cannot open file: " << ipf;
    exit_error(IRSTLM_ERROR_IO, ss_msg.str());
  }

  cerr << "loading parameters from " << ipf << "\n";

  // check compatibility
  // The buffer and both values are initialized so that an empty or truncated
  // file can never leave them indeterminate.
  char header[100] = "";
  inp.getline(header,100);
  int value1=0,value2=0;
  // sscanf must convert BOTH integers; otherwise the header is not a valid
  // parameter-file header and must be rejected instead of compared.
  const bool headerOk = (sscanf(header,"%d %d",&value1,&value2) == 2);

  if (!headerOk || value1 != lmsize() || value2 != pmax) {
    std::stringstream ss_msg;
    ss_msg << "parameter file " << ipf << " is incompatible";
    exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
  }

  // Same traversal order as the writer: level-major, then parameter index.
  for (int i=0; i<=lmsize(); i++)
    for (int j=0; j<pmax; j++)
      inp.readx(l[i][j],sizeof(double),numslm);

  return 1;
}
コード例 #2
0
ファイル: mixture.cpp プロジェクト: MistSC/kaldi-trunk
// Writes the mixture interpolation weights to the parameter file `opf`.
//
// The file starts with a text header line "<lmsize()> <pmax>", followed by
// numslm doubles written with writex() for every (level, parameter) pair,
// with level in [0, lmsize()] and parameter in [0, pmax).
//
// Returns 1 on success. If the output stream cannot be opened, an error is
// printed on cerr and 0 is returned without writing anything.
int mixture::savepar(char* opf)
{
  mfstream out(opf,ios::out);

  // Do not claim success when nothing can be written.
  if (!out) {
    cerr << "cannot open file: " << opf << "\n";
    return 0;
  }

  cerr << "saving parameters in " << opf << "\n";
  out << lmsize() << " " << pmax << "\n";

  for (int i=0; i<=lmsize(); i++)
    for (int j=0; j<pmax; j++)
      out.writex(l[i][j],sizeof(double),numslm);

  return 1;
}
コード例 #3
0
ファイル: shiftlm.cpp プロジェクト: shyamjvs/cs626_project
// Builds a ShiftBeta-smoothed LM on top of mdiadaptlm.
//
// ngtfile, depth, tt : forwarded unchanged to the mdiadaptlm constructor.
// prunefreq          : stored in prunethresh.
// b                  : discount stored in beta[2..lmsize()]. It must be
//                      either strictly between 0 and 1, or exactly -1.0
//                      (the value shiftbeta::train() tests for before
//                      estimating the discount itself).
//
// Any other value of b prints an error on cerr and terminates the process
// with exit(1).
shiftbeta::shiftbeta(char* ngtfile,int depth,int prunefreq,double b,TABLETYPE tt):
  mdiadaptlm(ngtfile,depth,tt)
{
  cerr << "Creating LM with ShiftBeta smoothing\n";

  const bool estimateLater = (b == -1.0);
  const bool validDiscount = (b > 0.0 && b < 1.0);
  if (!estimateLater && !validDiscount) {
    cerr << "shiftbeta: beta must be < 1.0 and > 0\n";
    exit (1);
  }

  // Value-initialize the table so that slots 0 and 1, which the loop below
  // does not assign, hold 0.0 instead of indeterminate values.
  beta=new double[lmsize()+1]();
  for (int l=2; l<=lmsize(); l++)
    beta[l]=b;

  prunethresh=prunefreq;
  cerr << "PruneThresh: " << prunethresh << "\n";
}
コード例 #4
0
ファイル: mixture.cpp プロジェクト: MistSC/kaldi-trunk
// Trains the mixture LM.
//
// Steps, in order:
//   1. checks that dub() is not smaller than the dictionary size
//      (otherwise exit_error() is called);
//   2. trains every sub-LM;
//   3. allocates the weights l[level][parameter][sublm] and sets them to
//      the uniform value 1/numslm;
//   4. either loads the weights from ipfname, or re-estimates them level by
//      level over the n-grams scanned from sublm[0], with at most 20
//      iterations per level;
//   5. saves the weights to opfname when it is set.
//
// Always returns 1.
int mixture::train()
{

  double zf;

  // Fixed seed; the only visible consumer, rand01(), is commented out below.
  srand(1333);

  genpmap();

  if (dub()<dict->size()) {
		std::stringstream ss_msg;
    ss_msg << "\nERROR: DUB value is too small: the LM will possibly compute wrong probabilities if sub-LMs have different vocabularies!\n";
		ss_msg << "This exception should already have been handled before!!!\n";
		exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
  }

  cerr << "mixlm --> DUB: " << dub() << endl;
  // Train each sub-LM; the OOV token is encoded in the sub-LM dictionary
  // before its train() is called.
  for (int i=0; i<numslm; i++) {
    cerr << i << " sublm --> DUB: " << sublm[i]->dub()  << endl;
    cerr << "eventually generate OOV code ";
    cerr << sublm[i]->dict->encode(sublm[i]->dict->OOV()) << "\n";
    sublm[i]->train();
  }

  //initialize parameters
  // Uniform weights for every level in [0, lmsize()] and every parameter
  // index in [0, pmax). Level 0 is allocated too, although the estimation
  // loop below starts at level 1.

  for (int i=0; i<=lmsize(); i++) {
    l[i]=new double*[pmax];
    for (int j=0; j<pmax; j++) {
      l[i][j]=new double[numslm];
      for (int k=0; k<numslm; k++)
        l[i][j][k]=1.0/(double)numslm;
    }
  }

  if (ipfname) {
    //load parameters from file
    loadpar(ipfname);
  } else {
    //start training of mixture model

    // NOTE(review): if pmax/numslm are not compile-time constants these are
    // variable-length arrays, which standard C++ does not provide (compiler
    // extension) -- confirm against the class declaration.
    double oldl[pmax][numslm];  // weights of the previous iteration
    char alive[pmax],used[pmax]; // alive: still being updated; used: hit by some n-gram at this level
    int totalive;

    ngram ng(sublm[0]->dict);

    for (int lev=1; lev<=lmsize(); lev++) {

      zf=sublm[0]->zerofreq(lev);

      cerr << "Starting training at lev:" << lev << "\n";

      // alive[] and used[] are reset once per level, not once per iteration.
      for (int i=0; i<pmax; i++) {
        alive[i]=1;
        used[i]=0;
      }
      totalive=1;
      int iter=0;
      // Iterate until no parameter is alive or 20 iterations have been run.
      while (totalive && (iter < 20) ) {

        iter++;

        // Save the current weights of the alive parameters and restart their
        // accumulators from 1/numslm (not from zero).
        for (int i=0; i<pmax; i++)
          if (alive[i])
            for (int j=0; j<numslm; j++) {
              oldl[i][j]=l[lev][i][j];
              l[lev][i][j]=1.0/(double)numslm;
            }

        sublm[0]->scan(ng,INIT,lev);
        while(sublm[0]->scan(ng,CONT,lev)) {

          //do not include oov for unigrams
          if ((lev==1) && (*ng.wordp(1)==sublm[0]->dict->oovcode()))
            continue;

          int par=pmap(ng,lev);
          used[par]=1;

          //check whether the parameter should be updated
          if (alive[par]) {

            double backoff=(lev>1?prob(ng,lev-1):1); //backoff
            double denom=0.0;
            double* numer = new double[numslm];
						double fstar,lambda;

            //int cv=(int)floor(zf * (double)ng.freq + rand01());
            //int cv=1; //old version of leaving-one-out
            int cv=(int)floor(zf * (double)ng.freq)+1;
            //int cv=1; //old version of leaving-one-out
            //if (lev==3)q

            //if (iter>10)
            // cout << ng
            // << " backoff " << backoff
            // << " level " << lev
            // << "\n";

            // numer[i]: previous weight of sub-LM i times its discounted
            // probability estimate (fstar + lambda * backoff);
            // denom: sum of numer[] over all sub-LMs.
            for (int i=0; i<numslm; i++) {

              //use cv if i=0
              // (i==0)*(cv) evaluates to cv for sublm[0] and to 0 for the others.

              sublm[i]->discount(ng,lev,fstar,lambda,(i==0)*(cv));
              numer[i]=oldl[par][i]*(fstar + lambda * backoff);

              // When the predicted word is OOV in sub-LM i and the mixture
              // DUB exceeds that sub-LM's dictionary size, divide the
              // contribution by the difference between the two.
              ngram ngslm(sublm[i]->dict);
              ngslm.trans(ng);
              if ((*ngslm.wordp(1)==sublm[i]->dict->oovcode()) &&
                  (dict->dub() > sublm[i]->dict->size()))
                numer[i]/=(double)(dict->dub() - sublm[i]->dict->size());

              denom+=numer[i];
            }

            // Accumulate the n-gram frequency, split among the sub-LMs
            // proportionally to numer[i]/denom.
            // NOTE(review): denom is not checked against zero here.
            for (int i=0; i<numslm; i++) {
              l[lev][par][i]+=(ng.freq * (numer[i]/denom));
              //if (iter>10)
              //cout << ng << " l: " << l[lev][par][i] << "\n";
            }
						delete []numer;
          }
        }

        //normalize all parameters
        totalive=0;
        for (int i=0; i<pmax; i++) {
          double tot=0;
          if (alive[i]) {
            for (int j=0; j<numslm; j++) tot+=(l[lev][i][j]);
            for (int j=0; j<numslm; j++) l[lev][i][j]/=tot;

            //decide if to continue to update
            // A parameter is frozen when no n-gram has hit it at this level,
            // or when reldist() between new and old weights is <= 0.05.
            if (!used[i] || (reldist(l[lev][i],oldl[i],numslm)<=0.05))
              alive[i]=0;
          }
          totalive+=alive[i];
        }

        cerr << "Lev " << lev << " iter " << iter << " tot alive " << totalive << "\n";

      }
    }
  }

  // The return value of savepar() is ignored here.
  if (opfname) savepar(opfname);


  return 1;
}
コード例 #5
0
ファイル: shiftlm.cpp プロジェクト: shyamjvs/cs626_project
// Computes the discounted relative frequency (fstar) and the back-off
// weight (lambda) of the n-gram `ng_` at order `size`.
//
// ng_    : n-gram to evaluate; it is translated into this LM's dictionary.
// size   : n-gram order. Orders <= 1 take the unigram path.
// fstar  : output, discounted frequency of the n-gram.
// lambda : output, probability mass left to the lower-order model.
// cv     : amount subtracted from the counts (cross-validation); it is
//          clamped to the n-gram frequency. Not used in the unigram path.
//
// Always returns 1. In the unigram path, a word that get() does not find
// prints a message on cerr and terminates the process with exit(1).
int mshiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
{
  ngram ng(dict);
  ng.trans(ng_);

  // ---- unigram case, no cross-validation ---------------------------------
  if (size <= 1) {
    lambda=0.0;

    int unigrtotfreq=(size<lmsize()?btotfreq():totfreq());

    if (!get(ng,size,size)) {
      cerr << "Missing probability for word: " << dict->decode(*ng.wordp(1)) << "\n";					
      exit(1);
    }
    fstar=(double) mfreq(ng,size)/(double)unigrtotfreq;
    return 1;
  }

  // ---- higher orders ------------------------------------------------------
  ngram history=ng;

  // The history must be valid, present, have more counts than cv and, from
  // order 3 upwards, its cv-corrected count must exceed the pruning threshold.
  const bool historyOk =
    ng.ckhisto(size) && get(history,size,size-1) && (history.freq > cv) &&
    ((size < 3) || ((history.freq-cv) > prunethresh ));

  if (!historyOk) {
    // Everything is left to the lower-order model.
    fstar=0;
    lambda=1;
    return 1;
  }

  // Successor statistics of the history, as returned by succ1()/succ2(),
  // plus the remaining successors.
  int suc[3];
  suc[0]=succ1(history.link);
  suc[1]=succ2(history.link);
  suc[2]=history.succ-suc[0]-suc[1];

  // The n-gram counts only if it is stored and not removed by the
  // singleton-pruning rules.
  const bool found =
    get(ng,size,size) &&
    (!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
    (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel());

  // heldOut: the n-gram is found but cv removes its whole count (ng.freq==cv).
  bool heldOut=false;
  if (found) {
    ng.freq=mfreq(ng,size);
    if (cv>ng.freq) cv=ng.freq;
    heldOut=!(ng.freq>cv);
  }

  // While the n-gram is held out it is temporarily removed from the
  // successor statistics of its history.
  if (heldOut) {
    if (ng.freq>=3) suc[2]--;
    else suc[ng.freq-1]--;
  }

  if (found && !heldOut) {
    double disc=((ng.freq-cv)>=3?beta[2][size]:beta[ng.freq-cv-1][size]);
    fstar=(double)((double)(ng.freq - cv) - disc)/(double)(history.freq-cv);
  } else {
    fstar=0.0;
  }

  lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
         /
         (double)(history.freq-cv);

  // Correction applied when singleton pruning is active at this order.
  if ((size>=3 && prunesingletons()) ||
      (size==maxlevel() && prunetopsingletons()))
    lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);

  // Restore the successor statistics.
  if (heldOut) {
    if (ng.freq>=3) suc[2]++;
    else suc[ng.freq-1]++;
  }

  // ---- OOV handling -------------------------------------------------------
  if (*ng.wordp(1)==dict->oovcode()) {
    // The predicted word is OOV: move its mass into lambda.
    lambda+=fstar;
    fstar=0.0;
    return 1;
  }

  // Otherwise add the discounted mass of the same history followed by OOV.
  *ng.wordp(1)=dict->oovcode();
  if (get(ng,size,size)) {
    ng.freq=mfreq(ng,size);
    if ((!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
        (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel())) {
      double disc=(ng.freq>=3?beta[2][size]:beta[ng.freq-1][size]);
      lambda+=(double)(ng.freq - disc)/(double)(history.freq-cv);
    }
  }

  return 1;
}
コード例 #6
0
ファイル: shiftlm.cpp プロジェクト: shyamjvs/cs626_project
// Estimates the three discount coefficients beta[0..2][level] for every
// level from 1 to lmsize(), using the number of n-grams whose mfreq() is
// 1, 2, 3 and 4 (n1..n4).
//
// When n1 or n2 is zero, or n1 <= n2, an error is printed and the process
// terminates with exit(1). When the higher counts are unusable, beta[1] and
// beta[2] fall back to the value of beta[0]. Negative coefficients are
// clamped to 0. At level 1, oovsum is also computed. Always returns 1.
int mshiftbeta::train()
{
	trainunigr();

	gencorrcounts();
	gensuccstat();

	ngram ng(dict);
	int unover3=0;   // unigrams with frequency >= 3 (only counted at level 1)

	oovsum=0;

	for (int lev=1; lev<=lmsize(); ++lev) {

		cerr << "level " << lev << "\n";

		cerr << "computing statistics\n";

		// count-of-counts for this level
		int n1=0, n2=0, n3=0, n4=0;

		scan(ng,INIT,lev);

		while (scan(ng,CONT,lev)) {

			if (lev>1) {
				// skip n-grams containing _OOV
				if (ng.containsWord(dict->OOV(),lev))
					continue;
				// skip n-grams containing </s> in context
				if (ng.containsWord(dict->EoS(),lev-1))
					continue;
			} else {
				// skip 1-grams containing <s>
				if (ng.containsWord(dict->BoS(),lev))
					continue;
			}

			ng.freq=mfreq(ng,lev);

			switch (ng.freq) {
			case 1:
				++n1;
				break;
			case 2:
				++n2;
				break;
			case 3:
				++n3;
				break;
			case 4:
				++n4;
				break;
			default:
				break;
			}

			if (lev==1 && ng.freq >=3)
				++unover3;
		}

		cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4;
		if (lev==1)
			cerr << " unover3: " << unover3;
		cerr << "\n";

		if (n1 == 0 || n2 == 0 ||  n1 <= n2) {
			cerr << "Error: lower order count-of-counts cannot be estimated properly\n";
			cerr << "Hint: use another smoothing method with this corpus.\n";
			exit(1);
		}

		const double Y=(double)n1/(double)(n1 + 2 * n2);
		beta[0][lev] = Y; //equivalent to  1 - 2 * Y * n2 / n1

		// The higher-order formulas are used only when n3 and n4 are non-zero
		// and the counts are strictly decreasing (n2 > n3 > n4).
		const bool higherCountsUsable = (n3 != 0 && n4 != 0 && n2 > n3 && n3 > n4);
		if (higherCountsUsable) {
			beta[1][lev] = 2 - 3 * Y * n3 / n2;
			beta[2][lev] = 3 - 4 * Y * n4 / n3;
		} else {
			cerr << "Warning: higher order count-of-counts cannot be estimated properly\n";
			cerr << "Fixing this problem by resorting only on the lower order count-of-counts\n";

			beta[1][lev] = Y;
			beta[2][lev] = Y;
		}

		// clamp negative coefficients, beta[1] first and then beta[2]
		for (int k=1; k<=2; ++k) {
			if (beta[k][lev] < 0) {
				cerr << "Warning: discount coefficient is negative \n";
				cerr << "Fixing this problem by setting beta to 0 \n";			
				beta[k][lev] = 0;
			}
		}

		if (lev==1)
			oovsum=beta[0][lev] * (double) n1 + beta[1][lev] * (double)n2 + beta[2][lev] * (double)unover3;

		cerr << beta[0][lev] << " " << beta[1][lev] << " " << beta[2][lev] << "\n";
	}

	return 1;
}
コード例 #7
0
ファイル: shiftlm.cpp プロジェクト: shyamjvs/cs626_project
// Estimates the ShiftBeta discount beta[l] for every level l from 2 to
// lmsize(); beta[1] is set to 0.
//
// For each level the n-grams are scanned once. For levels below lmsize(),
// the number of successors with frequency 1 is stored for every scanned
// n-gram through succ1(); this happens before the skip tests, so skipped
// n-grams get the statistic too. N-grams containing _OOV, or containing
// </s> in their context, are then left out of the singleton (n1) and
// doubleton (n2) counts.
//
// A level whose beta[l] is -1 gets n1/(n1+2*n2), or 1.0 when there are no
// singletons; other values are left as they are. Always returns 1.
int shiftbeta::train()
{
  ngram ng(dict);

  trainunigr();

  beta[1]=0.0;

  // The loop starts at level 2, so no unigram-specific handling is needed
  // inside it.
  for (int l=2; l<=lmsize(); l++) {

    cerr << "level " << l << "\n";
    int n1=0;
    int n2=0;
    scan(ng,INIT,l);
    while(scan(ng,CONT,l)) {

      if (l<lmsize()) {
        //Computing succ1 statistics for this ngram
        //to correct smoothing due to singleton pruning

        ngram hg=ng;
        get(hg,l,l);
        int s1=0;
        ngram ng2=hg;
        ng2.pushc(0);

        succscan(hg,ng2,INIT,l+1);
        while(succscan(hg,ng2,CONT,l+1)) {
          if (ng2.freq==1) s1++;
        }
        succ1(hg.link,s1);
      }

      //skip ngrams containing _OOV
      if (ng.containsWord(dict->OOV(),l))
        continue;

      //skip n-grams containing </s> in context
      if (ng.containsWord(dict->EoS(),l-1))
        continue;

      if (ng.freq==1) n1++;
      else if (ng.freq==2) n2++;

    }

    //compute statistics of shiftbeta smoothing
    if (beta[l]==-1) {
      if (n1>0)
        beta[l]=(double)n1/(double)(n1 + 2 * n2);
      else {
        cerr << "no singletons! \n";
        beta[l]=1.0;
      }
    }
    cerr << beta[l] << "\n";
  }

  return 1;
}