/*
 *    add effects of document to stats
 *
 *    when using multis, reset ddM.Mi[] to be totals
 *    for training words only, ignore test words in hold
 */
int add_doc(int d, enum GibbsType fix) {
  int i, t, w, nd=0;
  int mi = 0;

  if ( ddP.bdk!=NULL )
    mi = ddM.MI[d];
  for (i=ddD.NdTcum[d]; i<ddD.NdTcum[d+1]; i++) {
    if ( fix!=GibbsHold || pctl_hold(i) ) {
      /*
       *   these words are for perplexity calculations
       */
      nd++;
    }
    if ( fix!=GibbsHold || !pctl_hold(i) ) {
      t = Z_t(ddS.z[i]);
      /*
       *   these words are for training
       */
      ddS.Ndt[d][t]++;
      ddS.NdT[d]++;
      if ( (ddP.bdk==NULL) || Z_issetr(ddS.z[i]) ) {
        if ( ddP.phi==NULL ) {
          int val;
          w = ddD.w[i];
          atomic_incr(ddS.NWt[t]);
          val = atomic_incr(ddS.Nwt[w][t]);
          if ( ddP.PYbeta && val==1 )
            fix_tableidword(w,t);
        }
      }
    }
    if ( (ddP.bdk!=NULL) && M_multi(i) )
      mi++;
  }
  if ( ddP.PYalpha ) {
    /*  initialise ddS.Tdt[d][*]  */
    /*
     *   adjust table count stats based on Ndt[d]
     */
    for (t=0; t<ddN.T; t++) {
      if ( ddS.Ndt[d][t]>0 ) {
        int val;
        ddS.Tdt[d][t] = 1;
        val = atomic_incr(ddS.TDt[t]);
        if ( val==1 )
          atomic_incr(ddS.TDTnz);
        atomic_incr(ddS.TDT);
      } else
        ddS.Tdt[d][t] = 0;
    }
  }
  // yap_message("add_doc(%d): %d\n", d, nd);
  return nd;
}
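/*
 *   Illustrative sketch, not part of the original code: add_doc() and
 *   remove_doc() below are intended to be used as a pair, so a caller
 *   reworking a document's topic assignments might look roughly like
 *   the following (recycle_doc is a hypothetical name).
 */
#if 0
static int recycle_doc(int d, enum GibbsType fix) {
  remove_doc(d, fix);     /*  clear Ndt[d][*], NdT[d] and table stats  */
  /*  ... resample ddS.z[i] for i in [ddD.NdTcum[d], ddD.NdTcum[d+1]) ...  */
  return add_doc(d, fix); /*  rebuild stats from the new assignments  */
}
#endif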
/*
 *    remove effects of document from stats
 */
int remove_doc(int d, enum GibbsType fix) {
  int i, t;

  for (t=0; t<ddN.T; t++)
    ddS.Ndt[d][t] = 0;
  ddS.NdT[d] = 0;
  if ( ddP.PYalpha ) {
    for (t=0; t<ddN.T; t++)
      if ( ddS.Tdt[d][t]>0 ) {
        int val;
        val = atomic_sub(ddS.TDt[t], ddS.Tdt[d][t]);
        atomic_sub(ddS.TDT, ddS.Tdt[d][t]);
        if ( val==0 ) {
          atomic_decr(ddS.TDTnz);
          ddS.Tlife[t] = 0;
        }
        ddS.Tdt[d][t] = 0;
      }
  }
  for (i=ddD.NdTcum[d]; i<ddD.NdTcum[d+1]; i++) {
    if ( fix!=GibbsHold || !pctl_hold(i) ) {
      /*
       *   these words are for training
       */
      if ( (ddP.bdk==NULL) || Z_issetr(ddS.z[i]) ) {
        t = Z_t(ddS.z[i]);
        if ( ddP.phi==NULL ) {
          int val;
          int w = ddD.w[i];
          atomic_decr(ddS.NWt[t]);
          // assert(ddS.Nwt[w][t]>0);
          val = atomic_decr(ddS.Nwt[w][t]);
          if ( ddP.PYbeta )
            if ( val==0 || ddS.Twt[w][t]>ddS.Nwt[w][t] )
              unfix_tableidword(w,t);   //  ???? WRONG ??
        }
      }
    }
  }
  return 0;
}
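/*
 *   Debugging sketch, an assumption based on a reading of the updates
 *   above rather than anything in the original source: after a matched
 *   add_doc()/remove_doc() pair the table totals should satisfy
 *       ddS.TDT   == sum_t ddS.TDt[t]
 *       ddS.TDTnz == #{ t : ddS.TDt[t]>0 }
 *   A single-threaded consistency check might look like this
 *   (check_table_totals is a hypothetical name).
 */
#if 0
static void check_table_totals(void) {
  int t, tot = 0, nz = 0;
  for (t=0; t<ddN.T; t++) {
    tot += ddS.TDt[t];
    if ( ddS.TDt[t]>0 ) nz++;
  }
  assert(tot==ddS.TDT);
  assert(nz==ddS.TDTnz);
}
#endif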
/********************************
 *     code for LDA
 *****************************/

double gibbs_lda(/*
                  *   fix==GibbsNone for standard ML training/testing
                  *   fix==GibbsHold for word hold-out testing,
                  *       same as GibbsNone but also handles
                  *       train and test words differently
                  */
                 enum GibbsType fix,
                 int did,    //  document index
                 int words,  //  do this many
                 float *p,   //  temp store
                 D_MiSi_t *dD
                 ) {
  int i, wid, t, mi=0;
  int e;
  double Z, tot;
  double logdoc = 0;
  int logdocinf = 0;
  int StartWord = ddD.NdTcum[did];
  int EndWord = StartWord + words;
  float dtip[ddN.T];
#ifdef MH_STEP
  double doc_side_cache[ddN.T];
  for (t=0; t<ddN.T; t++)
    doc_side_cache[t] = doc_side_fact(did,t);
#endif

  /*
   *   some of the latent variables are not sampled but are kept
   *   fixed in the testing version; this uses enum GibbsType:
   *      fix = global document setting
   *      fix_doc = setting for each word in this doc
   *
   *   NB.  if fix==GibbsNone, then fix_doc==fix
   *        if fix==GibbsHold then fix_doc==GibbsHold or GibbsNone
   */
  enum GibbsType fix_doc = fix;

  if ( PCTL_BURSTY() ) {
    mi = ddM.MI[did];
  }
  e = ddD.e[did];
  for (i=StartWord; i<EndWord; i++) {
#ifdef MH_STEP
    int oldt;
#endif
    if ( fix==GibbsHold ) {
      if ( pctl_hold(i) )
        fix_doc = GibbsHold;  //  this word is a hold-out
      else
        fix_doc = GibbsNone;
    }
    // check_m_vte(e);
    wid=ddD.w[i];
    /*******************
     *   first we remove effects of this word on the stats
     *******************/
#ifdef MH_STEP
    oldt =
#endif
      t = Z_t(ddS.z[i]);
    if ( fix_doc!=GibbsHold ) {
      if ( remove_topic(i, did, (!PCTL_BURSTY()||Z_issetr(ddS.z[i]))?wid:-1,
                        t, mi, dD) ) {
        goto endword;
      }
    }
    /***********************
     *   get topic probabilities
     ***********************/
    // check_m_vte(e);
#ifdef MU_CACHE
    mu_side_fact_update(e);
#endif
#ifdef PHI_CACHE
    phi_norm_update(wid, e);
    phi_sum_update(wid, e, i);
#endif
    for (t=0, Z=0, tot=0; t<ddN.T; t++) {
#ifdef MH_STEP
      int saveback = ddP.back;
      if ( fix_doc!=GibbsHold )
        ddP.back = 0;
#endif
      /*
       *  (fix_doc==GibbsHold) ==>
       *       doing estimation, not sampling, so use *prob() versions
       *  else
       *       doing sampling, so use *fact() versions
       */
#ifdef MH_STEP
      double tf = (fix_doc==GibbsHold)?doc_side_prob(did,t):
        doc_side_cache[t];
      if ( tf>0 ) {
        double wf = (fix_doc==GibbsHold)?word_side_prob(e, wid, t):
          word_side_fact(e, wid, t);
#else
      double tf = (fix_doc==GibbsHold)?doc_side_prob(did,t):
        doc_side_fact(did,t);
      if ( tf>0 ) {
        double wf = (fix_doc==GibbsHold)?word_side_prob(e, wid, t):
          word_side_fact(e, wid, t);
#endif
        tot += tf;
        if ( PCTL_BURSTY() )
          wf = (fix_doc==GibbsHold)?docprob(dD, t, i, mi, wf):
            docfact(dD, t, i, mi, wf, &dtip[t]);
        Z += p[t] = tf * wf;
      } else
        p[t] = 0;
#ifdef MH_STEP
      ddP.back = saveback;
#endif
    }
    if ( fix!=GibbsHold || fix_doc==GibbsHold )
      logdoc += log(Z/tot);
    if ( logdocinf==0 )
      if ( !finite(logdoc) ) {
        logdocinf++;
        yap_infinite(logdoc);
      }
    /*******************
     *   now sample t using p[] and install effects of this on the stats;
     *   but note this needs the indicator to be set!
     *******************/
    if ( fix_doc!=GibbsHold ) {
      /*
       *  sample and update core stats
       */
      t = samplet(p, Z, ddN.T, rng_unit(rngp));
#ifdef MH_STEP
      if ( t != oldt ) {
        double ratio = p[oldt]/p[t];
        if ( PCTL_BURSTY() ) {
          ratio *= docfact(dD, t, i, mi, word_side_fact(e, wid, t), &dtip[t])
            * doc_side_fact(did,t);
          ratio /= docfact(dD, oldt, i, mi, word_side_fact(e, wid, oldt), &dtip[oldt])
            * doc_side_fact(did,oldt);
        } else {
          ratio *= word_side_fact(e, wid, t) * doc_side_fact(did, t);
          ratio /= word_side_fact(e, wid, oldt) * doc_side_fact(did, oldt);
        }
        if ( ratio<1 && ratio<rng_unit(rngp) )
          t = oldt;
      }
#endif
      Z_sett(ddS.z[i],t);
#ifdef TRACE_WT
      if ( wid==TR_W && t==TR_T )
        yap_message("update_topic(w=%d,t=%d,d=%d,l=%d,z=%d,N=%d,T=%d)\n",
                    wid,t,did,i,ddS.z[i],
                    (int)ddS.m_vte[wid][t][e],(int)ddS.s_vte[wid][t][e]);
#endif
      update_topic(i, did, wid, t, mi, dtip[t], dD);
#ifdef TRACE_WT
      if ( wid==TR_W && t==TR_T )
        yap_message("after update_topic(w=%d,t=%d,d=%d,l=%d,z=%d,N=%d,T=%d)\n",
                    wid,t,did,i,ddS.z[i],
                    (int)ddS.m_vte[wid][t][e],(int)ddS.s_vte[wid][t][e]);
#endif
    }
  endword:
    if ( PCTL_BURSTY() && M_multi(i) ) {
      mi++;
    }
  }
  return logdoc;
}
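/*
 *   Hedged note on the MH_STEP branch above, a reading of the code rather
 *   than the author's documentation: with MH_STEP, p[] is built from the
 *   cached doc-side factors, so the sampled t is a proposal q(t) and not
 *   the exact conditional pi(t).  The correction accepts the move from
 *   oldt to t with probability
 *       min(1, (pi(t)/pi(oldt)) * (q(oldt)/q(t)))
 *   where q(oldt)/q(t) = p[oldt]/p[t] and pi(.) is recomputed with fresh
 *   *_fact() calls.  In generic form (mh_accept is a hypothetical name):
 */
#if 0
static int mh_accept(double pi_new, double pi_old,
                     double q_new, double q_old, double u) {
  double ratio = (pi_new/pi_old) * (q_old/q_new);
  return ratio>=1 || u<ratio;   /*  u ~ Uniform(0,1)  */
}
#endif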
/*
 *    logic taken from the core of the Gibbs sampler
 */
static void query_docprob(int did, int *mimap, float *p, D_MiSi_t *dD,
                          float *cnt, float *wordscore) {
  int l, t, wid;
  double Z, tot;
  int Td_ = 0;
  double *tp;

  /*
   *   doing estimation, not sampling, so use the *prob() versions
   *   of estimates, not the *fact() versions
   */
  tp = dvec(ddN.T);
  if ( ddP.PYalpha )
    Td_ = comp_Td(did);
  for (t=0; t<ddN.T; t++)
    tp[t] = topicprob(did,t,Td_);

  for (l=0; l<ddP.n_words; l++) {
    int cmax = 0;
    wid = ddP.qword[l];
    if ( ddP.query[wid]!=l )
      /*  word has occurred before, so drop  */
      continue;
    for (t=0, Z=0, tot=0; t<ddN.T; t++) {
      /*
       *   doing estimation, not sampling, so use prob versions
       */
      double tf = tp[t];
      if ( tf>0 ) {
        double wf = wordprob(wid, t);
        tot += tf;
        if ( ddP.bdk!=NULL ) {
          int n, s;
          /*
           *   with burstiness:
           *   reproduce some of the logic in docprob(), but
           *   using our local data structures
           */
          if ( mimap[l]>ddN.N ) {
            /*
             *    word doesn't occur in the doc
             */
            n = s = 0;
          } else if ( mimap[l]<0 ) {
            /*
             *    word occurs once in the doc
             */
            int z = Z_t(ddS.z[-mimap[l]-1]);
            n = s = (z==t)?1:0;
          } else {
            /*
             *    it's a multi
             */
            int mii = ddM.multiind[mimap[l]]-dD->mi_base;
            assert(mii>=0);
            assert(mii<ddM.MI_max);
            n = dD->Mik[mii][t];
            s = dD->Sik[mii][t];
          }
          wf = (wf*(ddP.bdk[t]+ddP.ad*dD->Si[t]) + (n-ddP.ad*s))
            / (ddP.bdk[t]+dD->Mi[t]);
          if ( cmax<n )
            cmax = n;
        }
        Z += p[t] = tf*wf;
      } else
        p[t] = 0;
    }
    if ( ddP.bdk!=NULL )
      cnt[l] += cmax;
    wordscore[l] += -log(Z/tot);
  }
  free(tp);
}
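/*
 *   Illustrative restatement, not in the original source: the burstiness
 *   adjustment used above mixes the base word probability wf with the
 *   within-document counts n (customers) and s (tables), in the style of
 *   a Pitman-Yor posterior predictive:
 *       wf' = ( wf*(bdk[t] + ad*S_t) + (n - ad*s) ) / (bdk[t] + M_t)
 *   Factored out (bursty_prob is a hypothetical name):
 */
#if 0
static double bursty_prob(double wf, double b, double a,
                          int n, int s, int Mt, int St) {
  return (wf*(b + a*(double)St) + ((double)n - a*(double)s))
    / (b + (double)Mt);
}
#endif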