/*------------------------------------------------------------------*/ void AzTrTreeFeat::consolidateInternalWeights(const AzDvect *v_inp_w, const AzTrTreeEnsemble_ReadOnly *ens, AzDvect *v_w) /* output */ const { int f_num = featNum(); /* #features */ v_w->reform(f_num); const double *inp_w = v_inp_w->point(); double *w = v_w->point_u(); int dtree_num = ens->size(); int fx; for (fx = 0; fx < f_num; ++fx) { const AzTrTreeFeatInfo *fp = f_inf.point(fx); if (fp->isRemoved) continue; if (fp->tx < 0 || fp->tx >= dtree_num) { throw new AzException("AzTrTreeFeat::consolidateInternalWeights", "tree# conflict"); } const AzTrTree_ReadOnly *my_tree = ens->tree(fp->tx); int nx = fp->nx; if (my_tree->node(nx)->isLeaf()) { int num; const int *feat_no = ip_featDef.point(fp->tx, &num); for ( ; ; ) { if (nx < 0) break; if (feat_no[nx] >= 0) { w[fx] += inp_w[feat_no[nx]]; } nx = my_tree->node(nx)->parent_nx; } } } }
/*------------------------------------------------------------------*/ void AzTrTreeFeat::dump(const AzOut &out, const char *header) const { if (out.isNull()) return; AzPrint o(out); o.printBegin(header); o.printSw("doAllowZeroWeightLeaf", doAllowZeroWeightLeaf); o.printSw("doCountRules", doCountRules); o.printSw("doCheckConsistency", doCheckConsistency); o.printEnd(); ip_featDef.dump(dmp_out, "ip_featDef"); int f_num = featNum(); int fx; for (fx = 0; fx < f_num; ++fx) { const AzTrTreeFeatInfo *fp = f_inf.point(fx); o.printBegin("", ", ", "="); o.inBrackets(fx); o.print("isRemoved", fp->isRemoved); o.print("tx", fp->tx); o.print("nx", fp->nx); o.print("rule length", fp->rule.length()); o.printEnd(); } }
/*------------------------------------------------------------------*/ void AzTrTreeFeat::updateMatrix(const AzDataForTrTree *data, const AzTrTreeEnsemble_ReadOnly *ens, AzBmat *b_tran) /* inout */ const { if (data == NULL) return; const char *eyec = "AzTrTreeFeat::updateMatrix"; int f_num = featNum(); if (ens->size() != treeNum()) { throw new AzException(eyec, "size of tree ensemble and #feat should be the same"); } int old_f_num = b_tran->colNum(); if (old_f_num > f_num) { throw new AzException(eyec, "#col is bigger than #feat"); } int fx; for (fx = 0; fx < old_f_num; ++fx) { if (f_inf.point(fx)->isRemoved) { b_tran->clear(fx); } } if (old_f_num == f_num) { if (old_f_num == 0) { /* for the rare case that no feature was generated */ b_tran->reform(data->dataNum(), f_num); /* added 9/13/2011 */ } return; } _updateMatrix(data, ens, old_f_num, b_tran); }
/*------------------------------------------------------------------*/ void AzTrTreeFeat::checkConsistency(const AzTrTreeEnsemble_ReadOnly *ens) const { const char *eyec = "AzTrTreeFeat::checkConsistency"; int dtree_num = ens->size(); if (ip_featDef.size() != dtree_num) { throw new AzException(eyec, "#tree conflict"); } int f_num = featNum(); if (sp_desc.size() > 0 && sp_desc.size() != f_num) { throw new AzException(eyec, "#feat conflict"); } int tx; for (tx = 0; tx < dtree_num; ++tx) { int num; const int *feat_no = ip_featDef.point(tx, &num); if (num != ens->tree(tx)->nodeNum()) { throw new AzException(eyec, "#node conflict"); } int nx; for (nx = 0; nx < num; ++nx) { /*--- ip_featDef -> dtree[] ---*/ double tree_w = ens->tree(tx)->node(nx)->weight; if (tree_w != 0 && feat_no[nx] < 0) { throw new AzException(eyec, "non-zero weight in the tree for non-feature?"); } /*--- ip_featDef -> f_inf ---*/ if (feat_no[nx] >= 0) { const AzTrTreeFeatInfo *fp = f_inf.point(feat_no[nx]); if (fp->isRemoved) { throw new AzException(eyec, "a removed feature is active in ip_featDef"); } if (fp->tx != tx || fp->nx != nx) { throw new AzException(eyec, "conflict between f_inf and ip_featDef"); } } } } /*--- f_inf -> ip_featDef ---*/ int fx; for (fx = 0; fx < f_num; ++fx) { const AzTrTreeFeatInfo *fp = f_inf.point(fx); if (fp->isRemoved) continue; if (fp->tx < 0 || fp->tx >= dtree_num) { throw new AzException(eyec, "f_inf is pointing non-existing tree"); } int num; const int *feat_no = ip_featDef.point(fp->tx, &num); if (fp->nx < 0 || fp->nx >= num) { throw new AzException(eyec, "f_inf is pointing non-existing node"); } if (feat_no[fp->nx] != fx) { throw new AzException(eyec, "conflict in feat# between f_inf and ip_featDef"); } } }
void show(const AzOut &out, const AzIntArr *ia_fxs) const { int ix; for (ix = 0; ix < ia_fxs->size(); ++ix) { int fx = ia_fxs->get(ix); AzBytArr s("???"); if (fx>=0 && fx<featNum()) desc(fx, &s); AzPrint::writeln(out, s); } }
int desc2fno(const char *fnm) const { int fx; for (fx = 0; fx < featNum(); ++fx) { AzBytArr s; desc(fx, &s); if (s.compare(fnm) == 0) { return fx; } } return -1; }
int equals(const char *kw) const { int fx; for (fx = 0; fx < featNum(); ++fx) { AzBytArr s; desc(fx, &s); if (s.compare(kw) == 0) { return fx; } } return -1; }
/*------------------------------------------------------------------*/ void AzTrTreeFeat::getWeight(const AzTrTreeEnsemble_ReadOnly *ens, AzDvect *v_w) const /* output */ { const char *eyec = "AzTrTreeFeat::getWeight"; int dtree_num = ens->size(); if (dtree_num != treeNum()) { throw new AzException(eyec, "#trees conflict"); } int feat_num = featNum(); v_w->reform(feat_num); int tx; for (tx = 0; tx < dtree_num; ++tx) { _getWeight(ens->tree(tx), tx, v_w); } }
/*------------------------------------------------------------------*/ void AzTrTreeFeat::_updateMatrix(const AzDataForTrTree *data, const AzTrTreeEnsemble_ReadOnly *ens, int old_f_num, /*--- output ---*/ AzBmat *b_tran) const { int data_num = data->dataNum(); int f_num = featNum(); if (old_f_num == 0) { b_tran->reform(data_num, f_num); } else { if (b_tran->rowNum() != data_num || b_tran->colNum() != old_f_num) { throw new AzException("AzTrTreeFeat::_updateMatrix", "b_tran has a wrong shape"); } b_tran->resize(f_num); } /*--- which trees are referred in the new features? ---*/ AzIntArr ia_tx; int fx; for (fx = old_f_num; fx < f_num; ++fx) { ia_tx.put(f_inf.point(fx)->tx); } ia_tx.unique(); /* remove duplication */ int tx_num; const int *txs = ia_tx.point(&tx_num); /*--- generate features ---*/ AzDataArray<AzIntArr> aia_fx_dx(f_num-old_f_num); int xx; for (xx = 0; xx < tx_num; ++xx) { int tx = txs[xx]; int dx; for (dx = 0; dx < data_num; ++dx) { genFeats(ens->tree(tx), tx, data, dx, old_f_num, &aia_fx_dx); } } /*--- load into the matrix ---*/ for (fx = old_f_num; fx < f_num; ++fx) { b_tran->load(fx, aia_fx_dx.point(fx-old_f_num)); } }
void contains(const AzStrArray *sp_kw, AzIntArr *ia_fxs) const { ia_fxs->reset(); if (sp_kw->size()==0) return; int fx; for (fx = 0; fx < featNum(); ++fx) { AzBytArr s; desc(fx, &s); int ix; for (ix = 0; ix < sp_kw->size(); ++ix) { if (s.contains(sp_kw->c_str(ix))) { ia_fxs->put(fx); break; } } } }
/*------------------------------------------------------------------*/ int AzTrTreeFeat::countNonzeroNodup(const AzDvect *v_w) const { const char *eyec = "AzTrTreeFeat::countNonzeroNodup"; if (!doCountRules) { return -1; } if (v_w->rowNum() != featNum()) { throw new AzException(eyec, "#feat conflict"); } AzIFarr ifa_rx_zerocount; int f_num = v_w->rowNum(); const double *w = v_w->point(); int fx; for (fx = 0; fx < f_num; ++fx) { const AzTrTreeFeatInfo *fp = f_inf.point(fx); if (fp->isRemoved) { continue; } if (w[fx] == 0) { int rule_idx = pool_rules.find(&fp->rule); if (rule_idx < 0) { throw new AzException(eyec, "rule not found in the pool?!"); } ifa_rx_zerocount.put(rule_idx, 1); } } int nz_f_num = pool_rules.size(); ifa_rx_zerocount.squeeze_Sum(); int num = ifa_rx_zerocount.size(); int ix; for (ix = 0; ix < num; ++ix) { int rule_idx; int zero_count = (int)ifa_rx_zerocount.get(ix, &rule_idx); if (zero_count >= pool_rules.getCount(rule_idx)) { --nz_f_num; } } return nz_f_num; }