int mixture::loadpar(char* ipf) { mfstream inp(ipf,ios::in); if (!inp) { std::stringstream ss_msg; ss_msg << "cannot open file: " << ipf; exit_error(IRSTLM_ERROR_IO, ss_msg.str()); } cerr << "loading parameters from " << ipf << "\n"; // check compatibility char header[100]; inp.getline(header,100); int value1,value2; sscanf(header,"%d %d",&value1,&value2); if (value1 != lmsize() || value2 != pmax) { std::stringstream ss_msg; ss_msg << "parameter file " << ipf << " is incompatible"; exit_error(IRSTLM_ERROR_DATA, ss_msg.str()); } for (int i=0; i<=lmsize(); i++) for (int j=0; j<pmax; j++) inp.readx(l[i][j],sizeof(double),numslm); return 1; }
int mixture::savepar(char* opf) { mfstream out(opf,ios::out); cerr << "saving parameters in " << opf << "\n"; out << lmsize() << " " << pmax << "\n"; for (int i=0; i<=lmsize(); i++) for (int j=0; j<pmax; j++) out.writex(l[i][j],sizeof(double),numslm); return 1; }
// Build a ShiftBeta-smoothed LM on top of mdiadaptlm.
// b == -1.0 is the sentinel meaning "estimate beta per level in train()";
// otherwise b must lie strictly inside (0,1) and is used for every level >= 2.
// prunefreq is stored as the pruning threshold.
shiftbeta::shiftbeta(char* ngtfile,int depth,int prunefreq,double b,TABLETYPE tt):
  mdiadaptlm(ngtfile,depth,tt)
{
  cerr << "Creating LM with ShiftBeta smoothing\n";

  // Reject anything that is neither the -1 sentinel nor a proper
  // discount in the open interval (0,1).
  const bool estimate_later = (b == -1.0);
  const bool fixed_discount = (b > 0.0 && b < 1.0);
  if (!estimate_later && !fixed_discount) {
    cerr << "shiftbeta: beta must be < 1.0 and > 0\n";
    exit (1);
  }

  // One slot per level; levels 2..lmsize() get b (or the -1 sentinel),
  // beta[1] is initialized later by train().
  beta=new double[lmsize()+1];
  for (int lev=2; lev<=lmsize(); lev++)
    beta[lev]=b;

  prunethresh=prunefreq;
  cerr << "PruneThresh: " << prunethresh << "\n";
};
// Train the mixture LM: train every sub-LM, then estimate the interpolation
// weights l[level][class][sublm] — either loaded from ipfname, or fitted by
// an EM-style iteration over the first sub-LM's ngrams. Optionally saves the
// weights to opfname. Returns 1; exits if the dictionary upper bound (DUB)
// is smaller than the mixture dictionary size.
int mixture::train()
{
  double zf;

  srand(1333); // fixed seed: keeps any randomized choices reproducible

  genpmap(); // build the ngram -> parameter-class mapping used by pmap()

  if (dub()<dict->size()) {
    std::stringstream ss_msg;
    ss_msg << "\nERROR: DUB value is too small: the LM will possibly compute wrong probabilities if sub-LMs have different vocabularies!\n";
    ss_msg << "This exception should already have been handled before!!!\n";
    exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
  }

  cerr << "mixlm --> DUB: " << dub() << endl;

  // Train each sub-LM, forcing its OOV code to exist in its dictionary.
  for (int i=0; i<numslm; i++) {
    cerr << i << " sublm --> DUB: " << sublm[i]->dub()  << endl;
    cerr << "eventually generate OOV code ";
    cerr << sublm[i]->dict->encode(sublm[i]->dict->OOV()) << "\n";
    sublm[i]->train();
  }

  //initialize parameters: uniform weight 1/numslm everywhere
  for (int i=0; i<=lmsize(); i++) {
    l[i]=new double*[pmax];
    for (int j=0; j<pmax; j++) {
      l[i][j]=new double[numslm];
      for (int k=0; k<numslm; k++)
        l[i][j][k]=1.0/(double)numslm;
    }
  }

  if (ipfname) {
    //load parameters from file
    loadpar(ipfname);
  } else {
    //start training of mixture model

    // NOTE(review): variable-length arrays are a compiler extension, not
    // standard C++ — pmax and numslm are runtime values here.
    double oldl[pmax][numslm];
    char alive[pmax],used[pmax]; // alive: class still updating; used: seen this pass
    int totalive;

    ngram ng(sublm[0]->dict);

    // Weights are fitted independently for each ngram level.
    for (int lev=1; lev<=lmsize(); lev++) {

      zf=sublm[0]->zerofreq(lev); // zero-frequency mass, drives the held-out count cv

      cerr << "Starting training at lev:" << lev << "\n";

      for (int i=0; i<pmax; i++) {
        alive[i]=1;
        used[i]=0;
      }
      totalive=1;
      int iter=0;

      // Iterate until every class converges or 20 iterations elapse.
      while (totalive && (iter < 20) ) {
        iter++;

        // Snapshot previous weights, reset accumulators to uniform.
        for (int i=0; i<pmax; i++)
          if (alive[i])
            for (int j=0; j<numslm; j++) {
              oldl[i][j]=l[lev][i][j];
              l[lev][i][j]=1.0/(double)numslm;
            }

        sublm[0]->scan(ng,INIT,lev);
        while(sublm[0]->scan(ng,CONT,lev)) {

          //do not include oov for unigrams
          if ((lev==1) && (*ng.wordp(1)==sublm[0]->dict->oovcode()))
            continue;

          int par=pmap(ng,lev); // parameter class this ngram maps to
          used[par]=1;

          //check whether this parameter still needs updating
          if (alive[par]) {

            double backoff=(lev>1?prob(ng,lev-1):1); //backoff
            double denom=0.0;
            double* numer = new double[numslm];
            double fstar,lambda;

            // leave-one-out count derived from the zero-frequency mass
            //int cv=(int)floor(zf * (double)ng.freq + rand01());
            //int cv=1; //old version of leaving-one-out
            int cv=(int)floor(zf * (double)ng.freq)+1;
            //int cv=1; //old version of leaving-one-out

            //if (lev==3)q
            //if (iter>10)
            // cout << ng
            //      << " backoff " << backoff
            //      << " level " << lev
            //      << "\n";

            // Responsibility of each sub-LM for this ngram, weighted by the
            // previous mixture weight of its class.
            for (int i=0; i<numslm; i++) {

              //use cv if i=0
              sublm[i]->discount(ng,lev,fstar,lambda,(i==0)*(cv));
              numer[i]=oldl[par][i]*(fstar + lambda * backoff);

              ngram ngslm(sublm[i]->dict);
              ngslm.trans(ng);
              // spread OOV probability over the words missing from this sub-LM
              if ((*ngslm.wordp(1)==sublm[i]->dict->oovcode()) &&
                  (dict->dub() > sublm[i]->dict->size()))
                numer[i]/=(double)(dict->dub() - sublm[i]->dict->size());

              denom+=numer[i];
            }

            // Accumulate expected counts per sub-LM.
            // NOTE(review): denom==0 would divide by zero here — presumably
            // cannot happen for scanned ngrams; confirm.
            for (int i=0; i<numslm; i++) {
              l[lev][par][i]+=(ng.freq * (numer[i]/denom));
              //if (iter>10)
              //cout << ng << " l: " << l[lev][par][i] << "\n";
            }
            delete []numer;
          }
        }

        //normalize all parameters
        totalive=0;
        for (int i=0; i<pmax; i++) {
          double tot=0;
          if (alive[i]) {
            for (int j=0; j<numslm; j++) tot+=(l[lev][i][j]);
            for (int j=0; j<numslm; j++) l[lev][i][j]/=tot;

            //decide if to continue to update: freeze classes that were never
            //seen or whose weights moved by at most 5% (relative distance)
            if (!used[i] || (reldist(l[lev][i],oldl[i],numslm)<=0.05))
              alive[i]=0;
          }
          totalive+=alive[i];
        }
        cerr << "Lev " << lev << " iter " << iter << " tot alive " << totalive << "\n";
      }
    }
  }

  if (opfname) savepar(opfname);

  return 1;
}
// Compute the smoothing components for ngram ng_ at order <size>:
//   fstar  - discounted relative frequency of the ngram,
//   lambda - probability mass reserved for backing off to order size-1.
// <cv> is a held-out count subtracted from observed counts (leave-one-out
// during mixture training; 0 in normal use). Uses the three per-order
// discounts beta[0..2][size] for counts 1, 2 and >=3.
// Returns 1; exits if a unigram probability is missing.
int mshiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
{
  ngram ng(dict);
  ng.trans(ng_);

  //cout << "size :" << size << " " << ng <<"\n";

  if (size > 1) {

    ngram history=ng;

    //singleton pruning only on real counts!!
    if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq > cv) &&
        ((size < 3) || ((history.freq-cv) > prunethresh ))) { // no history pruning with corrected counts!

      // successor statistics of the history: successors seen once (suc[0]),
      // twice (suc[1]) and three-or-more times (suc[2])
      int suc[3];
      suc[0]=succ1(history.link);
      suc[1]=succ2(history.link);
      suc[2]=history.succ-suc[0]-suc[1];

      if (get(ng,size,size) &&
          (!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
          (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel())) {

        ng.freq=mfreq(ng,size);

        cv=(cv>ng.freq)?ng.freq:cv; // never hold out more than was observed

        if (ng.freq>cv) {

          // discount chosen by the (held-out-corrected) count: 1, 2 or >=3
          double b=(ng.freq-cv>=3?beta[2][size]:beta[ng.freq-cv-1][size]);

          fstar=(double)((double)(ng.freq - cv) - b)/(double)(history.freq-cv);

          lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2]) / (double)(history.freq-cv);

          if ((size>=3 && prunesingletons()) || (size==maxlevel() && prunetopsingletons())) //correction
            // presumably returns to lambda the mass of pruned singletons — TODO confirm
            lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);

        } else { // ng.freq==cv: the whole observed count is held out

          // temporarily remove this ngram from the successor statistics
          ng.freq>=3?suc[2]--:suc[ng.freq-1]--; //update successor stat

          fstar=0.0;
          lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2]) / (double)(history.freq-cv);

          if ((size>=3 && prunesingletons()) || (size==maxlevel() && prunetopsingletons())) //correction
            lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);

          ng.freq>=3?suc[2]++:suc[ng.freq-1]++; //resume successor stat

        }
      } else {
        // ngram absent or pruned: no direct probability, only backoff mass
        fstar=0.0;
        lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2]) / (double)(history.freq-cv);

        if ((size>=3 && prunesingletons()) || (size==maxlevel() && prunetopsingletons())) //correction
          lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
      }

      //cerr << "ngram :" << ng << "\n";

      if (*ng.wordp(1)==dict->oovcode()) {
        // ngram itself predicts OOV: fold its mass into the backoff weight
        lambda+=fstar;
        fstar=0.0;
      } else {
        // also move into lambda the discounted mass of <history,OOV>, if present
        *ng.wordp(1)=dict->oovcode();
        if (get(ng,size,size)) {
          ng.freq=mfreq(ng,size);
          if ((!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
              (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel())) {
            double b=(ng.freq>=3?beta[2][size]:beta[ng.freq-1][size]);
            lambda+=(double)(ng.freq - b)/(double)(history.freq-cv);
          }
        }
      }
    } else {
      // unseen (or held-out/pruned) history: defer entirely to the lower order
      fstar=0;
      lambda=1;
    }
  } else { // unigram case, no cross-validation
    lambda=0.0;
    int unigrtotfreq=(size<lmsize()?btotfreq():totfreq());
    if (get(ng,size,size))
      fstar=(double) mfreq(ng,size)/(double)unigrtotfreq;
    else {
      cerr << "Missing probability for word: " << dict->decode(*ng.wordp(1)) << "\n";
      exit(1);
    }
  }
  return 1;
}
// Estimate the three discount constants beta[0..2][l] (applied to ngrams of
// count 1, 2 and >=3) for every level l from the count-of-counts n1..n4,
// following the Chen-Goodman-style formulas. At the unigram level it also
// accumulates oovsum, the total discounted mass. Returns 1; exits when the
// lower-order count-of-counts are unusable (n1<=n2 or zero).
int mshiftbeta::train()
{
  trainunigr();     // unigram probabilities
  gencorrcounts();  // corrected counts used by mfreq()
  gensuccstat();    // successor statistics used by discount()

  ngram ng(dict);
  int n1,n2,n3,n4;  // number of ngrams occurring exactly 1, 2, 3, 4 times
  int unover3=0;    // unigrams occurring 3 or more times
  oovsum=0;

  for (int l=1; l<=lmsize(); l++) {

    cerr << "level " << l << "\n";

    cerr << "computing statistics\n";

    n1=0;
    n2=0;
    n3=0,n4=0;

    scan(ng,INIT,l);

    while(scan(ng,CONT,l)) {

      //skip ngrams containing _OOV
      if (l>1 && ng.containsWord(dict->OOV(),l)) {
        //cerr << "skp ngram" << ng << "\n";
        continue;
      }

      //skip n-grams containing </s> in context
      if (l>1 && ng.containsWord(dict->EoS(),l-1)) {
        //cerr << "skp ngram" << ng << "\n";
        continue;
      }

      //skip 1-grams containing <s>
      if (l==1 && ng.containsWord(dict->BoS(),l)) {
        //cerr << "skp ngram" << ng << "\n";
        continue;
      }

      ng.freq=mfreq(ng,l);

      if (ng.freq==1) n1++;
      else if (ng.freq==2) n2++;
      else if (ng.freq==3) n3++;
      else if (ng.freq==4) n4++;
      if (l==1 && ng.freq >=3) unover3++;
    }

    if (l==1) {
      cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << " unover3: " << unover3 << "\n";
    } else {
      cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << "\n";
    }

    // The discount formulas below require n1 > n2 > 0.
    if (n1 == 0 || n2 == 0 || n1 <= n2) {
      cerr << "Error: lower order count-of-counts cannot be estimated properly\n";
      cerr << "Hint: use another smoothing method with this corpus.\n";
      exit(1);
    }

    double Y=(double)n1/(double)(n1 + 2 * n2);
    beta[0][l] = Y; //equivalent to 1 - 2 * Y * n2 / n1

    // Fall back to Y for the higher-order discounts when the higher
    // count-of-counts are missing or non-monotone.
    if (n3 ==0 || n4 == 0 || n2 <= n3 || n3 <= n4 ){
      cerr << "Warning: higher order count-of-counts cannot be estimated properly\n";
      cerr << "Fixing this problem by resorting only on the lower order count-of-counts\n";
      beta[1][l] = Y;
      beta[2][l] = Y;
    } else{
      beta[1][l] = 2 - 3 * Y * n3 / n2;
      beta[2][l] = 3 - 4 * Y * n4 / n3;
    }

    // Clamp negative discounts to zero: a negative beta would add mass.
    if (beta[1][l] < 0){
      cerr << "Warning: discount coefficient is negative \n";
      cerr << "Fixing this problem by setting beta to 0 \n";
      beta[1][l] = 0;
    }

    if (beta[2][l] < 0){
      cerr << "Warning: discount coefficient is negative \n";
      cerr << "Fixing this problem by setting beta to 0 \n";
      beta[2][l] = 0;
    }

    // total mass removed by discounting at the unigram level
    if (l==1)
      oovsum=beta[0][l] * (double) n1 + beta[1][l] * (double)n2 + beta[2][l] * (double)unover3;

    cerr << beta[0][l] << " " << beta[1][l] << " " << beta[2][l] << "\n";
  }
  return 1;
};
// Estimate one discount constant beta[l] per ngram level (l >= 2) from the
// singleton/doubleton counts: beta = n1 / (n1 + 2*n2). Levels whose beta was
// fixed at construction time (beta[l] != -1 sentinel) are left untouched.
// As a side effect, for every ngram below the top level it stores the number
// of its singleton successors (succ1), used later to correct smoothing under
// singleton pruning. Returns 1.
int shiftbeta::train()
{
  ngram ng(dict);
  int n1,n2;

  trainunigr(); // unigram probabilities

  beta[1]=0.0;  // no discounting at the unigram level

  for (int l=2; l<=lmsize(); l++) {

    cerr << "level " << l << "\n";
    n1=0;
    n2=0;
    scan(ng,INIT,l);
    while(scan(ng,CONT,l)) {

      if (l<lmsize()) {
        //Computing succ1 statistics for this ngram
        //to correct smoothing due to singleton pruning
        ngram hg=ng;
        get(hg,l,l);
        int s1=0;
        ngram ng2=hg;
        ng2.pushc(0); // extend the history with a placeholder word slot
        succscan(hg,ng2,INIT,l+1);
        while(succscan(hg,ng2,CONT,l+1)) {
          if (ng2.freq==1) s1++;
        }
        succ1(hg.link,s1);
      }

      //skip ngrams containing _OOV
      if (l>1 && ng.containsWord(dict->OOV(),l)) {
        //cerr << "skp ngram" << ng << "\n";
        continue;
      }

      //skip n-grams containing </s> in context
      if (l>1 && ng.containsWord(dict->EoS(),l-1)) {
        //cerr << "skp ngram" << ng << "\n";
        continue;
      }

      //skip 1-grams containing <s>
      if (l==1 && ng.containsWord(dict->BoS(),l)) {
        //cerr << "skp ngram" << ng << "\n";
        continue;
      }

      if (ng.freq==1) n1++;
      else if (ng.freq==2) n2++;
    }

    //compute statistics of shiftbeta smoothing
    if (beta[l]==-1) {
      if (n1>0)
        beta[l]=(double)n1/(double)(n1 + 2 * n2);
      else {
        // degenerate corpus with no singletons: discount everything
        cerr << "no singletons! \n";
        beta[l]=1.0;
      }
    }
    cerr << beta[l] << "\n";
  }
  return 1;
};