Exemplo n.º 1
0
/*-------------------------------------------------------------------*/
void AzRgforest::cold_start(const char *param,
                            const AzSmat *m_x,
                            const AzDvect *v_y,
                            const AzSvFeatInfo *featInfo,
                            const AzDvect *v_fixed_dw,
                            const AzOut &out_req)
{
    out = out_req;
    s_config.reset(param);

    AzParam az_param(param);
    int max_tree_num = resetParam(az_param);
    setInput(az_param, m_x, featInfo);
    reg_depth->reset(az_param, out);  /* init regularizer on node depth */
    v_p.reform(v_y->rowNum());
    opt->cold_start(loss_type, data, reg_depth, /* initialize optimizer */
                    az_param, v_y, v_fixed_dw, out, &v_p);
    initTarget(v_y, v_fixed_dw);
    initEnsemble(az_param, max_tree_num); /* initialize tree ensemble */
    fs->reset(az_param, reg_depth, out); /* initialize node search */
    az_param.check(out);
    l_num = 0; /* initialize leaf node counter */

    if (!beVerbose) {
        out.deactivate(); /* shut up after printing everyone's config */
    }

    time_init(); /* initialize time measure ment */
    end_of_initialization();
}
Exemplo n.º 2
0
/*--------------------------------------------------------*/
void AzsSvrg::train_test_classif(const char *param, 
                         const AzDSmat *_m_trn_x, const AzIntArr *_ia_trn_lab, 
                         const AzDSmat *_m_tst_x, const AzIntArr *_ia_tst_lab, 
                         int _class_num)
{
  /*---  set data info into class variables so that everyone can see ... ---*/
  m_trn_x = _m_trn_x; ia_trn_lab = _ia_trn_lab; 
  m_tst_x = _m_tst_x; ia_tst_lab = _ia_tst_lab; 
  v_trn_y = v_tst_y = NULL; 
 
  class_num = _class_num; 
  if (class_num == 2) {
    AzTimeLog::print("Binary classification ... ", log_out); 
    class_num = 1; 
  }
  
  /*---  parse parameters  ---*/
  AzParam azp(param); 
  resetParam(azp); 
  printParam(log_out); 
  azp.check(log_out); 

  /*---  training and testing  ---*/  
  _train_test(); 
}
Exemplo n.º 3
0
  virtual void reset_data(const AzOut &out, 
                  const AzSmat *m_data, 
                  AzParam &p, 
                  bool beTight, 
                  const AzSvFeatInfo *inp_feat=NULL)
  {
    resetParam(p); 
    printParam(out); 

    /*---  count nonzero components  ---*/
    double nz_ratio; 
    m_data->nonZeroNum(&nz_ratio); 
    AzBytArr s("Training data: "); 
    s.cn(m_data->rowNum());s.c("x");s.cn(m_data->colNum()); 
    s.c(", nonzero_ratio=", nz_ratio, 4); 

    /*---  decide sparse or dense  ---*/
    AzBytArr s_dp("; managed as dense data"); 
    bool doSparse = false; 
    if (dataproc == dataproc_Auto && 
        nz_ratio < Az_nz_ratio_threshold || 
        dataproc == dataproc_Sparse) { 
      doSparse = true; 
      s_dp.reset("; managed as sparse data"); 
    }
    if (dataproc != dataproc_Auto) s_dp.concat(" as requested."); 
    else                           s_dp.concat("."); 
    AzPrint::writeln(out, "-------------"); 
    AzPrint::writeln(out, s, s_dp); 
    AzPrint::writeln(out, "-------------"); 

   /*---  pre-sort data  ---*/
    m_tran_sparse.reset(); 
    m_tran_dense.unlock(); 
    m_tran_dense.reset(); 
    data_num = m_data->colNum(); 
    if (doSparse) {
      m_data->transpose(&m_tran_sparse); 
      sorted_arr.reset_sparse(&m_tran_sparse, beTight); 
    }
    else {
      m_tran_dense.transpose_from(m_data); 
      sorted_arr.reset_dense(&m_tran_dense, beTight); 
      /* prohibit any action to change the pointers to the column vectors */
      m_tran_dense.lock(); 
    }
    if (inp_feat != NULL) {
      feat.reset(inp_feat); 
      if (feat.featNum() != m_data->rowNum()) {
        throw new AzException(AzInputError, "AzDataForTrTree::reset", "#feat mismatch"); 
      }
    }
    else {
      feat.reset(m_data->rowNum()); 
    }
  }
Exemplo n.º 4
0
/*-------------------------------------------------------------------*/
void AzRgforest::warm_start(const char *param,
                            const AzSmat *m_x,
                            const AzDvect *v_y,
                            const AzSvFeatInfo *featInfo,
                            const AzDvect *v_fixed_dw,
                            const AzTreeEnsemble *inp_ens,
                            const AzOut &out_req)
{
    const char *eyec = "AzRgforest::warm_start";
    out = out_req;
    if (inp_ens->orgdim() > 0 &&
            inp_ens->orgdim() != m_x->rowNum()) {
        AzBytArr s("Mismatch in feature dimensionality.  ");
        s.cn(inp_ens->orgdim());
        s.c(" (tree ensemeble), ");
        s.cn(m_x->rowNum());
        s.c(" (dataset).");
        throw new AzException(AzInputError, eyec, s.c_str());
    }
    s_config.reset(param);

    AzParam az_param(param);
    int max_tree_num = resetParam(az_param);

    warmup_timer(inp_ens, max_tree_num); /* timers are modified for warm-start */

    setInput(az_param, m_x, featInfo);

    AzTimeLog::print("Warming-up trees ... ", log_out);
    warmupEnsemble(az_param, max_tree_num, inp_ens); /* v_p is set */

    reg_depth->reset(az_param, out);  /* init regularizer on node depth */

    AzTimeLog::print("Warming-up the optimizer ... ", log_out);
    opt->warm_start(loss_type, data, reg_depth, /* initialize optimizer */
                    az_param, v_y, v_fixed_dw, out,
                    ens, &v_p);

    initTarget(v_y, v_fixed_dw);

    fs->reset(az_param, reg_depth, out); /* initialize node search */
    az_param.check(out);
    l_num = ens->leafNum();  /* warm-up #leaf */

    if (!beVerbose) {
        out.deactivate(); /* shut up after printing everyone's config */
    }

    time_init(); /* initialize time measure ment */
    end_of_initialization();
    AzTimeLog::print("End of warming-up ... ", log_out);
}
Exemplo n.º 5
0
/*--------------------------------------------------------*/
/*
 * Re-initialize the optimizer.
 *
 *   l_type          loss type, stored in loss_type
 *   inp_v_y         targets, copied into v_y
 *   inp_v_fixed_dw  user-assigned data point weights; copied only when
 *                   AzDvect::isNull() reports that it is not null
 *   inp_reg_depth   regularizer on depth; only the pointer is kept
 *   param           parameters, consumed by resetParam()
 *   beVerbose       when false, both output channels are deactivated
 *                   after the parameters have been printed and checked
 *   out_req         output destination, copied into the member `out`
 *   inp_ens, inp_tree_feat, inp_v_p
 *                   warm-start inputs; used only when inp_ens is not NULL
 */
void AzOptOnTree::reset(AzLossType l_type, 
                        const AzDvect *inp_v_y, 
                        const AzDvect *inp_v_fixed_dw, /* user-assigned data point weights */
                        const AzRegDepth *inp_reg_depth, 
                        AzParam &param, 
                        bool beVerbose, 
                        const AzOut out_req, 
                        /*---  for warm start  ---*/
                        const AzTrTreeEnsemble_ReadOnly *inp_ens, 
                        const AzTrTreeFeat *inp_tree_feat, 
                        const AzDvect *inp_v_p)

{
  _reset();

  /*---  references and output channels  ---*/
  reg_depth = inp_reg_depth;
  out = out_req;
  my_dmp_out = dmp_out;

  /*---  targets, optional weights, predictions  ---*/
  v_y.set(inp_v_y);
  if (!AzDvect::isNull(inp_v_fixed_dw)) {
    v_fixed_dw.set(inp_v_fixed_dw);
  }
  v_p.reform(v_y.rowNum());
  loss_type = l_type;

  /*---  parameters  ---*/
  resetParam(param);

  /* a negative max_delta becomes 1 for exponential-family losses */
  if (max_delta < 0) {
    if (AzLoss::isExpoFamily(loss_type)) {
      max_delta = 1;
    }
  }

  printParam(out, beVerbose);
  checkParam();

  if (!beVerbose) {
    out.deactivate();
    my_dmp_out.deactivate();
  }

  /*---  constants; with doUseAvg the predictions start at the mean of y  ---*/
  var_const = 0;
  fixed_const = 0;
  if (doUseAvg) {
    fixed_const = v_y.sum() / static_cast<double>(v_y.rowNum());
    v_p.set(fixed_const);
  }

  /*---  warm start  ---*/
  if (inp_ens != NULL) {
    _warmup(inp_ens, inp_tree_feat, inp_v_p);
  }
}
Exemplo n.º 6
0
 virtual void reset(int argc, const char *argv[], const AzOut &out) {
   if (argc < 2) {
     if (argc >= 1) {
       AzPrint::writeln(out, ""); 
       AzPrint::writeln(out, argv[0]); /* action */
     }    
     printHelp(out); 
     throw new AzException(AzNormal, "", ""); 
   }
   const char *action = argv[0]; 
   AzParam azp(param_dlm, argc-1, argv+1); 
   AzPrint::writeln(out, ""); 
   AzPrint::writeln(out, "--------------------------------------------------");     
   AzPrint::writeln(out, action, " ", azp.c_str()); 
   AzPrint::writeln(out, "--------------------------------------------------"); 
   resetParam(azp); 
   printParam(out); 
   azp.check(out);    
 }
Exemplo n.º 7
0
/*--------------------------------------------------------*/
void AzsSvrg::train_test_regress(const char *param, 
                         const AzDSmat *_m_trn_x, const AzDvect *_v_trn_y, 
                         const AzDSmat *_m_tst_x, const AzDvect *_v_tst_y)
{
  /*---  set data info into class variables so that everyone can see ... ---*/
  m_trn_x = _m_trn_x; v_trn_y = _v_trn_y; 
  m_tst_x = _m_tst_x; v_tst_y = _v_tst_y; 
  ia_trn_lab = ia_tst_lab = NULL; 
 
  class_num = 1; 
  AzTimeLog::print("Regression ... ", log_out); 

  /*---  parse parameters  ---*/
  AzParam azp(param); 
  resetParam(azp); 
  printParam(log_out); 
  azp.check(log_out); 

  /*---  training and testing  ---*/
  _train_test(); 
}
Exemplo n.º 8
0
/*------------------------------------------------------------------*/
/*
 * Clear the tree-feature bookkeeping and re-read the parameters.
 *
 *   data                       source of the original feature info
 *   param                      parameters, consumed by resetParam()
 *   out_req                    output destination, copied into `out`
 *   inp_doAllowZeroWeightLeaf  stored in doAllowZeroWeightLeaf
 *
 * Output is deactivated at the end unless resetParam() reports verbose.
 */
void AzTrTreeFeat::reset(const AzDataForTrTree *data, 
                         AzParam &param, 
                         const AzOut &out_req, 
                         bool inp_doAllowZeroWeightLeaf)
{
  out = out_req;

  /*---  start from the data's feature info and empty definitions  ---*/
  org_featInfo.reset(data->featInfo());
  ip_featDef.reset();

  /*---  the three pools get the same initial size and average length  ---*/
  int pool_size = 10000;
  int pool_avg_len = 32;
  pool_rules.reset(pool_size, pool_avg_len);
  pool_rules_rmved.reset(pool_size, pool_avg_len);
  sp_desc.reset(pool_size, pool_avg_len);

  f_inf.reset();
  doAllowZeroWeightLeaf = inp_doAllowZeroWeightLeaf;

  /*---  parameters; resetParam() reports whether to be verbose  ---*/
  bool verbose = resetParam(param);
  printParam(out);

  if (!verbose) out.deactivate();
}