Exemple #1
0
/*------------------------------------------------------------------*/
void AzTrTreeFeat::consolidateInternalWeights(const AzDvect *v_inp_w, 
                                          const AzTrTreeEnsemble_ReadOnly *ens, 
                                          AzDvect *v_w) /* output */
                                          const
{
  int f_num = featNum();  /* #features */
  v_w->reform(f_num); 

  const double *inp_w = v_inp_w->point(); 
  double *w = v_w->point_u(); 

  int dtree_num = ens->size(); 

  int fx; 
  for (fx = 0; fx < f_num; ++fx) {
    const AzTrTreeFeatInfo *fp = f_inf.point(fx); 
    if (fp->isRemoved) continue; 
    if (fp->tx < 0 || fp->tx >= dtree_num) {
      throw new AzException("AzTrTreeFeat::consolidateInternalWeights", "tree# conflict"); 
    }
    const AzTrTree_ReadOnly *my_tree = ens->tree(fp->tx); 
    int nx = fp->nx; 
    if (my_tree->node(nx)->isLeaf()) {
      int num; 
      const int *feat_no = ip_featDef.point(fp->tx, &num); 
      for ( ; ; ) {
        if (nx < 0) break; 
        if (feat_no[nx] >= 0) {
          w[fx] += inp_w[feat_no[nx]]; 
        }
        nx = my_tree->node(nx)->parent_nx; 
      }
    }
  }
}
Exemple #2
0
/*------------------------------------------------------------------*/
/*
 * Write a diagnostic dump of this object to "out".
 *
 *   out    : destination; nothing is written when out.isNull().
 *   header : text passed to printBegin() for the leading option line.
 *
 * Output consists of the three option switches, the ip_featDef dump, and
 * then one line per feature with its index, isRemoved flag, tree index
 * (tx), node index (nx) and rule length.
 */
void AzTrTreeFeat::dump(const AzOut &out, const char *header) const
{
  if (out.isNull()) return; 
 
  AzPrint o(out); 
  o.printBegin(header); 
  o.printSw("doAllowZeroWeightLeaf", doAllowZeroWeightLeaf); 
  o.printSw("doCountRules", doCountRules); 
  o.printSw("doCheckConsistency", doCheckConsistency); 
  o.printEnd(); 

  /* Send the feature-definition dump to the caller's stream as well; it
     previously went to the global dmp_out regardless of "out", splitting
     one report across two destinations. */
  ip_featDef.dump(out, "ip_featDef"); 
  
  const int f_num = featNum(); 
  for (int fx = 0; fx < f_num; ++fx) {
    const AzTrTreeFeatInfo *fp = f_inf.point(fx); 
    
    o.printBegin("", ", ", "="); 
    o.inBrackets(fx); 
    o.print("isRemoved", fp->isRemoved); 
    o.print("tx", fp->tx); 
    o.print("nx", fp->nx); 
    o.print("rule length", fp->rule.length()); 
    o.printEnd(); 
  }
}
Exemple #3
0
/*------------------------------------------------------------------*/
/*
 * Bring the feature matrix b_tran up to date with the current feature set.
 *
 *   data   : training data; when NULL the call is a no-op.
 *   ens    : tree ensemble; its size must equal treeNum().
 *   b_tran : in/out matrix with one column per feature.
 *
 * Columns of already-known features that are flagged isRemoved are
 * cleared.  If features were added since the matrix was last updated
 * (colNum() < featNum()), the new columns are produced by _updateMatrix().
 * If there are no features at all, the matrix is reformed to
 * data->dataNum() rows and zero columns.
 *
 * Throws AzException (by pointer) when the ensemble size differs from
 * treeNum() or when the matrix has more columns than there are features.
 */
void AzTrTreeFeat::updateMatrix(const AzDataForTrTree *data, 
                                const AzTrTreeEnsemble_ReadOnly *ens, 
                                AzBmat *b_tran) /* inout */
                                const
{
  if (data == NULL) {
    return;
  }

  const char *eyec = "AzTrTreeFeat::updateMatrix";
  const int feat_count = featNum();
  if (ens->size() != treeNum()) {
    throw new AzException(eyec, "size of tree ensemble and #feat should be the same");
  }
  const int prev_col_count = b_tran->colNum();
  if (prev_col_count > feat_count) {
    throw new AzException(eyec, "#col is bigger than #feat");
  }

  /* wipe columns that belong to removed features */
  for (int col = 0; col < prev_col_count; ++col) {
    if (f_inf.point(col)->isRemoved) {
      b_tran->clear(col);
    }
  }

  if (prev_col_count != feat_count) {
    /* new features exist: generate their columns */
    _updateMatrix(data, ens, prev_col_count, b_tran);
    return;
  }

  if (prev_col_count == 0) {
    /* for the rare case that no feature was generated */
    b_tran->reform(data->dataNum(), feat_count);  /* added 9/13/2011 */
  }
}
Exemple #4
0
/*------------------------------------------------------------------*/
void AzTrTreeFeat::checkConsistency(const AzTrTreeEnsemble_ReadOnly *ens) const
{
  const char *eyec = "AzTrTreeFeat::checkConsistency"; 

  int dtree_num = ens->size(); 
  if (ip_featDef.size() != dtree_num) {
    throw new AzException(eyec, "#tree conflict"); 
  }
  int f_num = featNum(); 
  if (sp_desc.size() > 0 && sp_desc.size() != f_num) {
    throw new AzException(eyec, "#feat conflict"); 
  }

  int tx; 
  for (tx = 0; tx < dtree_num; ++tx) {
    int num; 
    const int *feat_no = ip_featDef.point(tx, &num); 
    if (num != ens->tree(tx)->nodeNum()) {
      throw new AzException(eyec, "#node conflict"); 
    }
    int nx; 
    for (nx = 0; nx < num; ++nx) {
      /*---  ip_featDef -> dtree[]  ---*/
      double tree_w = ens->tree(tx)->node(nx)->weight; 
      if (tree_w != 0 && feat_no[nx] < 0) {
        throw new AzException(eyec, "non-zero weight in the tree for non-feature?"); 
      }

      /*---  ip_featDef -> f_inf  ---*/
      if (feat_no[nx] >= 0) {
        const AzTrTreeFeatInfo *fp = f_inf.point(feat_no[nx]); 
        if (fp->isRemoved) {
          throw new AzException(eyec, "a removed feature is active in ip_featDef"); 
        }
        if (fp->tx != tx || fp->nx != nx) {
          throw new AzException(eyec, "conflict between f_inf and ip_featDef"); 
        }
      }
    }
  }

  /*---  f_inf -> ip_featDef  ---*/
  int fx; 
  for (fx = 0; fx < f_num; ++fx) {
    const AzTrTreeFeatInfo *fp = f_inf.point(fx); 
    if (fp->isRemoved) continue; 
    
    if (fp->tx < 0 || fp->tx >= dtree_num) {
      throw new AzException(eyec, "f_inf is pointing non-existing tree"); 
    }
    int num; 
    const int *feat_no = ip_featDef.point(fp->tx, &num);     
    if (fp->nx < 0 || fp->nx >= num) {
      throw new AzException(eyec, "f_inf is pointing non-existing node");       
    }
    if (feat_no[fp->nx] != fx) {
      throw new AzException(eyec, "conflict in feat# between f_inf and ip_featDef"); 
    }
  }
}
 void show(const AzOut &out, const AzIntArr *ia_fxs) const {
   int ix; 
   for (ix = 0; ix < ia_fxs->size(); ++ix) {
     int fx = ia_fxs->get(ix); 
     AzBytArr s("???"); 
     if (fx>=0 && fx<featNum()) desc(fx, &s); 
     AzPrint::writeln(out, s); 
   }
 }
 int desc2fno(const char *fnm) const {
   int fx; 
   for (fx = 0; fx < featNum(); ++fx) {
     AzBytArr s; 
     desc(fx, &s); 
     if (s.compare(fnm) == 0) {
       return fx; 
     }
   }
   return -1; 
 }
 int equals(const char *kw) const {
   int fx; 
   for (fx = 0; fx < featNum(); ++fx) {
     AzBytArr s; 
     desc(fx, &s); 
     if (s.compare(kw) == 0) {
       return fx; 
     }
   } 
   return -1; 
 }
Exemple #8
0
/*------------------------------------------------------------------*/
void AzTrTreeFeat::getWeight(const AzTrTreeEnsemble_ReadOnly *ens, 
                         AzDvect *v_w) const /* output */
{
  const char *eyec = "AzTrTreeFeat::getWeight"; 
  int dtree_num = ens->size();
  if (dtree_num != treeNum()) {
    throw new AzException(eyec, "#trees conflict"); 
  }
  int feat_num = featNum(); 
  v_w->reform(feat_num); 

  int tx; 
  for (tx = 0; tx < dtree_num; ++tx) {
    _getWeight(ens->tree(tx), tx, v_w);    
  }
}
Exemple #9
0
/*------------------------------------------------------------------*/
void AzTrTreeFeat::_updateMatrix(const AzDataForTrTree *data, 
                          const AzTrTreeEnsemble_ReadOnly *ens, 
                          int old_f_num, 
                          /*---  output  ---*/
                          AzBmat *b_tran) const
{
  int data_num = data->dataNum(); 
  int f_num = featNum(); 
  if (old_f_num == 0) {
    b_tran->reform(data_num, f_num); 
  }
  else {
    if (b_tran->rowNum() != data_num || 
        b_tran->colNum() != old_f_num) {
      throw new AzException("AzTrTreeFeat::_updateMatrix", 
                            "b_tran has a wrong shape"); 
    }
    b_tran->resize(f_num); 
  }

  /*---  which trees are referred in the new features?  ---*/
  AzIntArr ia_tx; 
  int fx; 
  for (fx = old_f_num; fx < f_num; ++fx) {
    ia_tx.put(f_inf.point(fx)->tx); 
  }
  ia_tx.unique(); /* remove duplication */
  int tx_num; 
  const int *txs = ia_tx.point(&tx_num); 

  /*---  generate features  ---*/
  AzDataArray<AzIntArr> aia_fx_dx(f_num-old_f_num); 
  int xx; 
  for (xx = 0; xx < tx_num; ++xx) {
    int tx = txs[xx]; 
    int dx; 
    for (dx = 0; dx < data_num; ++dx) {
      genFeats(ens->tree(tx), tx, data, dx, 
               old_f_num, &aia_fx_dx); 
    }
  }

  /*---  load into the matrix  ---*/
  for (fx = old_f_num; fx < f_num; ++fx) {
    b_tran->load(fx, aia_fx_dx.point(fx-old_f_num)); 
  }  
}
 void contains(const AzStrArray *sp_kw, 
                 AzIntArr *ia_fxs) const {
   ia_fxs->reset(); 
   if (sp_kw->size()==0) return; 
   int fx; 
   for (fx = 0; fx < featNum(); ++fx) {
     AzBytArr s; 
     desc(fx, &s); 
     int ix; 
     for (ix = 0; ix < sp_kw->size(); ++ix) {
       if (s.contains(sp_kw->c_str(ix))) {
         ia_fxs->put(fx); 
         break; 
       }
     }
   } 
 }
Exemple #11
0
/*------------------------------------------------------------------*/
int AzTrTreeFeat::countNonzeroNodup(const AzDvect *v_w) const
{
  const char *eyec = "AzTrTreeFeat::countNonzeroNodup"; 

  if (!doCountRules) {
    return -1; 
  }

  if (v_w->rowNum() != featNum()) {
    throw new AzException(eyec, "#feat conflict"); 
  }

  AzIFarr ifa_rx_zerocount; 

  int f_num = v_w->rowNum(); 
  const double *w = v_w->point(); 
  int fx; 
  for (fx = 0; fx < f_num; ++fx) {
    const AzTrTreeFeatInfo *fp = f_inf.point(fx); 
    if (fp->isRemoved) {
      continue; 
    }
    if (w[fx] == 0) {
      int rule_idx = pool_rules.find(&fp->rule); 
      if (rule_idx < 0) {
        throw new AzException(eyec, "rule not found in the pool?!"); 
      }
      ifa_rx_zerocount.put(rule_idx, 1); 
    }
  }

  int nz_f_num = pool_rules.size(); 
  ifa_rx_zerocount.squeeze_Sum(); 
  int num = ifa_rx_zerocount.size(); 
  int ix; 
  for (ix = 0; ix < num; ++ix) {
    int rule_idx; 
    int zero_count = (int)ifa_rx_zerocount.get(ix, &rule_idx); 
    if (zero_count >= pool_rules.getCount(rule_idx)) {
      --nz_f_num; 
    }
  }
  return nz_f_num; 
}