/*
 * Read kw_rho and kw_eps (under prefix pfx) into rho and eps.
 * coeff is unconditionally set to 1.  is_warmstart is accepted for
 * interface compatibility but is not consulted here.
 */
void resetParam(AzParam &azp, const char *pfx, bool is_warmstart=false)
{
  azp.reset_prefix(pfx);
  azp.vFloat(kw_rho, &rho);
  azp.vFloat(kw_eps, &eps);
  azp.reset_prefix();

  coeff = 1;
}
/*------------------------------------------------------------------*/
/*
 * Read the doCountRules and doCheckConsistency switches into members.
 * Returns a local flag that starts as false and is passed to
 * p.swOn(..., kw_opt_beVerbose); it is not stored in the object.
 */
bool AzTrTreeFeat::resetParam(AzParam &p)
{
  p.swOn(&doCountRules, kw_doCountRules);
  p.swOn(&doCheckConsistency, kw_doCheckConsistency);

  bool opt_verbose = false;
  p.swOn(&opt_verbose, kw_opt_beVerbose);
  return opt_verbose;
}
/*------------------------------------------------------------*/
/*
 * Return the dataset number given by kw_dsno under prefix pfx.
 * If the keyword is omitted, or the value read is negative, dflt_dsno is
 * returned instead.
 */
int AzpCNet3_multi::getParam_dsno(AzParam &azp, const char *pfx, int dflt_dsno) const
{
  int dsno = dflt_dsno;

  azp.reset_prefix(pfx);
  azp.vInt(kw_dsno, &dsno);
  azp.reset_prefix();

  /* negative => fall back to the caller-supplied default */
  return (dsno < 0) ? dflt_dsno : dsno;
}
/*
 * Read activation settings under prefix pfx.
 * Only when is_warmstart is false are kw_activ_typ and kw_trunc read;
 * typ is set to the first byte of a non-empty kw_activ_typ value.
 * The do_stat switch is read in either case.
 */
virtual void resetParam(AzParam &azp, const char *pfx, bool is_warmstart)
{
  azp.reset_prefix(pfx);

  const bool is_cold_start = !is_warmstart;
  if (is_cold_start) {
    azp.vStr(kw_activ_typ, &s_activ_typ);
    if (s_activ_typ.length() > 0) {
      typ = *s_activ_typ.point(); /* first byte of the string */
    }
    azp.vFloat(kw_trunc, &trunc);
  }

  azp.swOn(&do_stat, kw_do_stat);
  azp.reset_prefix();
}
/*------------------------------------------------*/
/*
 * Read image data parameters (channels, sz1, sz2, data_scale) and validate:
 * channels/sz1/sz2 must be positive, s_x_ext is checked against sp_x_ext,
 * and s_y_ext is checked against sp_y_ext only when it is non-empty.
 * is_training is not consulted here.
 */
virtual void resetParam_data(AzParam &azp, bool is_training)
{
  const char *eyec = "AzpData_img::resetParam_data";

  /*--- geometry and scaling ---*/
  azp.vInt(kw_channels, &channels);
  azp.vInt(kw_sz1, &sz1);
  azp.vInt(kw_sz2, &sz2);
  azp.vFloat(kw_data_scale, &data_scale);

  /*--- validation ---*/
  AzXi::throw_if_nonpositive(channels, eyec, kw_channels);
  AzXi::throw_if_nonpositive(sz1, eyec, kw_sz1);
  AzXi::throw_if_nonpositive(sz2, eyec, kw_sz2);
  AzXi::check_input(s_x_ext, &sp_x_ext, eyec, kw_x_ext);

  const bool has_y_ext = (s_y_ext.length() > 0);
  if (has_y_ext) {
    AzXi::check_input(s_y_ext, &sp_y_ext, eyec, kw_y_ext);
  }
}
/**@
 * Reset the tree-level parameters: max_depth, min_size, max_leaf_num,
 * and the doUseInternalNodes / beVerbose switches, then call adjustParam().
 * When beVerbose is off, my_dmp_out and out are deactivated.
 */
void AzRgfTree::resetParam(AzParam &p)
{
  /*--- size limits ---*/
  p.vInt(kw_max_depth, &max_depth);
  p.vInt(kw_min_size, &min_size);
  p.vInt(kw_max_leaf_num, &max_leaf_num);

  /*--- switches ---*/
  p.swOn(&doUseInternalNodes, kw_doUseInternalNodes);
  p.swOn(&beVerbose, kw_tree_beVerbose);

  const bool be_quiet = !beVerbose;
  if (be_quiet) {
    my_dmp_out.deactivate();
    out.deactivate();
  }

  adjustParam();
}
void warm_start(const AzTreeEnsemble *inp_ens, const AzDataForTrTree *data, AzParam ¶m, const AzBytArr *s_temp_prefix, const AzOut &out, int max_t_num, int search_t_num, AzDvect *v_p, /* inout */ const AzIntArr *inp_ia_tr_dx=NULL) { const char *eyec = "AzTrTreeEnsemble::warmup"; if (max_t_num < inp_ens->size()) { throw new AzException(eyec, "maximum #tree is less than the #tree we already have"); } reset(); a_tree.alloc(&t, max_t_num, "AzTrTreeEnsemble::warmup"); t_num = inp_ens->size(); const_val = inp_ens->constant(); org_dim = inp_ens->orgdim(); if (org_dim > 0 && org_dim != data->featNum()) { throw new AzException(AzInputError, eyec, "feature dimensionality mismatch"); } const AzIntArr *ia_tr_dx = inp_ia_tr_dx; AzIntArr ia_temp; if (ia_tr_dx == NULL) { ia_temp.range(0, data->dataNum()); ia_tr_dx = &ia_temp; } v_p->reform(data->dataNum()); v_p->add(const_val, ia_tr_dx); T dummy_tree(param); if (dummy_tree.usingInternalNodes()) { throw new AzException(AzInputError, eyec, "warm start is not allowed with use of internal nodes"); } dummy_tree.printParam(out); temp_files.reset(&dummy_tree, data->dataNum(), s_temp_prefix); s_param.reset(param.c_str()); dt_param = s_param.c_str(); AzParam p(dt_param, false); int tx; for (tx = 0; tx < t_num; ++tx) { t[tx] = new T(p); t[tx]->forStoringDataIndexes(temp_files.point_file()); if (search_t_num > 0 && tx < t_num-search_t_num) { t[tx]->quick_warmup(inp_ens->tree(tx), data, v_p, ia_tr_dx); } else { t[tx]->warmup(inp_ens->tree(tx), data, v_p, ia_tr_dx); } } }
/*--- for parameters ---*/
/*
 * Read kw_dataproc into s_dataproc and map it onto dataproc:
 * empty or "Auto" -> dataproc_Auto, "Sparse" -> dataproc_Sparse,
 * "Dense" -> dataproc_Dense; anything else throws AzException.
 */
virtual void resetParam(AzParam &p)
{
  p.vStr(kw_dataproc, &s_dataproc);
  dataproc = dataproc_Auto;

  const bool is_auto = (s_dataproc.length() <= 0 || s_dataproc.compare("Auto") == 0);
  if (is_auto) {
    return;
  }
  if (s_dataproc.compare("Sparse") == 0) {
    dataproc = dataproc_Sparse;
    return;
  }
  if (s_dataproc.compare("Dense") == 0) {
    dataproc = dataproc_Dense;
    return;
  }
  throw new AzException(AzInputNotValid, kw_dataproc, "must be either \"Auto\", \"Sparse\", or \"Dense\".");
}
/*------------------------------------------------------------------*/
/*
 * Read the weight-optimization parameters.  If max_ite_num ends up
 * non-positive, it is replaced by max_ite_num_dflt_expo when loss_type is
 * in the exponential family (AzLoss::isExpoFamily), else by
 * max_ite_num_dflt_oth.
 */
void AzOptOnTree::resetParam(AzParam &p)
{
  /*--- numeric settings ---*/
  p.vFloat(kw_lambda, &lambda);
  p.vFloat(kw_sigma, &sigma);
  p.vInt(kw_max_ite_num, &max_ite_num);
  p.vFloat(kw_eta, &eta);
  p.vFloat(kw_exit_delta, &exit_delta);
  p.vFloat(kw_max_delta, &max_delta);

  /*--- switches ---*/
  p.swOn(&doUseAvg, kw_doUseAvg);
  p.swOff(&doIntercept, kw_not_doIntercept);
  /* useless but keep this for compatibility */
  p.swOn(&doIntercept, kw_doIntercept);

  if (max_ite_num > 0) {
    return;
  }
  /* not specified (or non-positive): pick a default by loss family */
  if (AzLoss::isExpoFamily(loss_type)) {
    max_ite_num = max_ite_num_dflt_expo;
  }
  else {
    max_ite_num = max_ite_num_dflt_oth;
  }
}
/*
 * Read pooling settings under prefix pfx.  Nothing is read on a warm start.
 * Otherwise ptyp becomes the first byte of kw_pl_type (or
 * AzpPoolingDflt_None when that is empty), and pl_num, pl_sz, pl_step and
 * the kw_no_pl_simple_grid switch are read.
 */
void resetParam(AzParam &azp, const char *pfx, bool is_warmstart)
{
  azp.reset_prefix(pfx);

  if (!is_warmstart) {
    azp.vStr(kw_pl_type, &s_pl_type);
    if (s_pl_type.length() <= 0) {
      ptyp = AzpPoolingDflt_None;
    }
    else {
      ptyp = *s_pl_type.point();
    }

    azp.vInt(kw_pl_num, &pl_num);
    azp.vInt(kw_pl_sz, &pl_sz);
    azp.vInt(kw_pl_step, &pl_step);
    azp.swOff(&do_pl_simple_grid, kw_no_pl_simple_grid);
  }

  azp.reset_prefix();
}
inline void cold_start( AzParam ¶m, const AzBytArr *s_temp_prefix, /* may be NULL */ int data_num, /* to estimate the data size for temp */ const AzOut &out, int tree_num_max, int inp_org_dim) { reset(); T dummy_tree(param); dummy_tree.printParam(out); s_param.reset(param.c_str()); dt_param = s_param.c_str(); alloc(tree_num_max, "AzTrTreeEnsemble::reset"); //@ allocate forest space org_dim = inp_org_dim; temp_files.reset(&dummy_tree, data_num, s_temp_prefix); //@ estimate the data size for temp and do something? }
/* output: sp_conn */
/*
 * Read layer connections from keywords formed as kw_conn followed by ix and
 * "=" (ix = 0, 1, 2, ...), stopping at the first empty value.  Each value is
 * split on conn_dlm and must yield exactly two layer ids (below, above),
 * which are parsed by parse_layerno().
 *
 * Output: iia_conn receives the (below, above) pairs, and sp_conn receives
 * the matching connection strings.  Both are cleared first, so they stay in
 * step across calls.
 * If no connection is specified, the chain lx -> lx+1 for
 * lx = 0 .. hid_num-1 is used.
 *
 * Throws if a value is malformed or if an edge goes out of the top layer
 * (below == hid_num).
 */
void AzpCNet3_multi::resetParam_conn(AzParam &azp, AzIIarr &iia_conn)
{
  const char *eyec = "AzpCNet3_multi::resetParam_conn";
  iia_conn.reset();
  sp_conn.reset(); /* was "sp_conn.size();", a discarded call that left stale entries */

  int top = hid_num;
  int ix;
  for (ix = 0; ; ++ix) {
    AzBytArr s_conn;
    AzBytArr s_kw(kw_conn);
    s_kw << ix << "=";
    azp.vStr(s_kw.c_str(), &s_conn);
    if (s_conn.length() <= 0) break;

    AzStrPool sp(32,32);
    AzTools::getStrings(s_conn.point(), s_conn.length(), conn_dlm, &sp);
    AzX::throw_if((sp.size() != 2), AzInputError, eyec, "Expected the format like n-m for", kw_conn);

    int below = parse_layerno(sp.c_str(0), hid_num);
    int above = parse_layerno(sp.c_str(1), hid_num);
    AzX::throw_if((below == top), AzInputError, eyec, "No edge is allowed to go out of the top layer");
    iia_conn.put(below, above);
    sp_conn.put(&s_conn);
  }

  /*--- default ---*/
  if (iia_conn.size() == 0) {
    for (int lx = 0; lx < hid_num; ++lx) {
      int below = lx, above = lx + 1;
      iia_conn.put(below, above);
      AzBytArr s_conn;
      s_conn << below << conn_dlm << above;
      sp_conn.put(s_conn.c_str());
    }
  }
}
/*--------------------------------------------------------*/
/*
 * Read and validate the forest-level parameters, print them to `out` when it
 * is not null, and return the maximum number of trees.
 *
 * - If kw_max_tree_num is not positive, it is derived as MAX(1, max_lnum/2);
 *   if max_lnum is not positive either, AzInputMissing is thrown.
 * - Non-positive kw_lnum_inc_opt / kw_s_tree_num / kw_lnum_inc_test,
 *   an unknown kw_mem_policy, f_ratio > 1, or loss_type == AzLoss_None
 *   throw AzInputNotValid.
 * - When 0 < f_ratio < 1, kw_random_seed is read, and srand() is called with
 *   it when it is positive.
 *
 * Fix: the seed test used to be "srand > 0", which compares the address of
 * the srand function (ill-formed, and never looks at the seed).  It now
 * tests random_seed.
 */
int AzRgforest::resetParam(AzParam &p)
{
  const char *eyec = "AzRgforest::resetParam";

  /*--- for storing data indexes in the trees to disk ---*/
  /*--- this must be called before adjustTestInterval. ---*/
  p.vStr(kw_temp_for_trees, &s_temp_for_trees);

  /*--- loss function ---*/
  p.vLoss(kw_loss, &loss_type);

  /*--- weight optimization interval ---*/
  int lnum_inc_opt = lnum_inc_opt_dflt;
  p.vInt(kw_lnum_inc_opt, &lnum_inc_opt);
  if (lnum_inc_opt <= 0) {
    throw new AzException(AzInputNotValid, eyec, kw_lnum_inc_opt, "must be positive");
  }
  opt_timer.reset(lnum_inc_opt);

  /*--- # of trees to search ---*/
  p.vInt(kw_s_tree_num, &s_tree_num);
  if (s_tree_num <= 0) {
    throw new AzException(AzInputNotValid, eyec, kw_s_tree_num, "must be positive");
  }

  /*--- when to stop: max #leaf, max #tree ---*/
  int max_tree_num = -1, max_lnum = max_lnum_dflt;
  p.vInt(kw_max_tree_num, &max_tree_num);
  p.vInt(kw_max_lnum, &max_lnum);
  if (max_tree_num <= 0) {
    if (max_lnum > 0) {
      max_tree_num = MAX(1, max_lnum / 2);
    }
    else {
      AzBytArr s("Specify "); s.c(kw_max_lnum); s.c(" and/or "); s.c(kw_max_tree_num);
      throw new AzException(AzInputMissing, eyec, s.c_str());
    }
  }
  lmax_timer.reset(max_lnum);

  /*--- when to test: test interval ---*/
  int lnum_inc_test = lnum_inc_test_dflt;
  p.vInt(kw_lnum_inc_test, &lnum_inc_test);
  if (lnum_inc_test <= 0) {
    throw new AzException(AzInputNotValid, eyec, kw_lnum_inc_test, "must be positive");
  }
  lnum_inc_test = adjustTestInterval(lnum_inc_test, lnum_inc_opt);
  test_timer.reset(lnum_inc_test);

  /*--- memory handling ---*/
  p.vStr(kw_mem_policy, &s_mem_policy);
  if (s_mem_policy.length() <= 0) beTight = false;
  else if (s_mem_policy.compare(mp_beTight) == 0) beTight = true;
  else if (s_mem_policy.compare(mp_not_beTight) == 0) beTight = false;
  else {
    AzBytArr s(kw_mem_policy); s.c(" should be either "); s.c(mp_beTight); s.c(" or "); s.c(mp_not_beTight);
    throw new AzException(AzInputNotValid, eyec, s.c_str());
  }

  /*--- f_ratio; a value strictly between 0 and 1 also reads the random seed ---*/
  p.vFloat(kw_f_ratio, &f_ratio);
  if (f_ratio > 1) {
    throw new AzException(AzInputNotValid, kw_f_ratio, "must be between 0 and 1.");
  }
  int random_seed = -1;
  if (f_ratio > 0 && f_ratio < 1) {
    p.vInt(kw_random_seed, &random_seed);
    if (random_seed > 0) { /* seed only when a positive seed was given */
      srand(random_seed);
    }
  }

  p.swOn(&doPassiveRoot, kw_doPassiveRoot);

  /*--- for maintenance purposes ---*/
  p.swOn(&doForceToRefreshAll, kw_doForceToRefreshAll);
  p.swOn(&beVerbose, kw_forest_beVerbose); /* for compatibility */
  p.swOn(&beVerbose, kw_beVerbose);
  p.swOn(&doTime, kw_doTime);

  /*--- display parameters ---*/
  if (!out.isNull()) {
    AzPrint o(out);
    o.ppBegin("AzRgforest", "Forest-level");
    o.printLoss(kw_loss, loss_type);
    o.printV(kw_max_lnum, max_lnum);
    o.printV(kw_max_tree_num, max_tree_num);
    o.printV(kw_lnum_inc_opt, lnum_inc_opt);
    o.printV(kw_lnum_inc_test, lnum_inc_test);
    o.printV(kw_s_tree_num, s_tree_num);
    o.printSw(kw_doForceToRefreshAll, doForceToRefreshAll);
    o.printSw(kw_beVerbose, beVerbose);
    o.printSw(kw_doTime, doTime);
    o.printV_if_not_empty(kw_mem_policy, s_mem_policy);
    o.printV_if_not_empty(kw_temp_for_trees, &s_temp_for_trees);
    o.printV(kw_f_ratio, f_ratio);
    o.printV(kw_random_seed, random_seed);
    o.printSw(kw_doPassiveRoot, doPassiveRoot);
    o.ppEnd();
  }

  if (loss_type == AzLoss_None) {
    throw new AzException(AzInputNotValid, eyec, kw_loss);
  }
  return max_tree_num;
}
/*-------------------------------------------------------------------------*/
/*
 * Read and validate the parameters for generating regions
 * (AzPrepText2_gen_regions_parsup_Param).
 * f_pch_sz, f_pch_step and f_padding default to pch_sz, 1 and pch_sz-1.
 * These defaults are computed from pch_sz before the f_* keywords are read.
 * kw_scale_y is read only when do_binarize is off.
 * Validation failures throw via AzXi / AzX, including f_pch_sz > dist.
 */
void resetParam(AzParam &azp)
{
  const char *eyec = "AzPrepText2_gen_regions_parsup_Param::resetParam";

  /*--- selection counts ---*/
  azp.vInt(kw_top_num_each, &top_num_each);
  azp.vInt(kw_top_num_total, &top_num_total);

  /*--- files and types ---*/
  azp.vStr(kw_feat_fn, &s_feat_fn);
  azp.vStr(kw_xtyp, &s_xtyp);
  azp.vStr(kw_xdic_fn, &s_xdic_fn);
  azp.vStr(kw_inp_fn, &s_inp_fn);
  azp.vStr(kw_rnm, &s_rnm);
  azp.vStr(kw_txt_ext, &s_txt_ext);

  /*--- regions ---*/
  azp.vInt(kw_pch_sz, &pch_sz);
  azp.vInt(kw_pch_step, &pch_step);
  azp.vInt(kw_padding, &padding);

  /*--- f_* regions: defaults derived from pch_sz, then overridable ---*/
  f_pch_sz = pch_sz;
  f_pch_step = 1;
  f_padding = f_pch_sz - 1;
  azp.vInt(kw_f_pch_sz, &f_pch_sz);
  azp.vInt(kw_f_pch_step, &f_pch_step);
  azp.vInt(kw_f_padding, &f_padding);

  azp.vInt(kw_min_x, &min_x);
  azp.vInt(kw_min_y, &min_y);
  azp.vInt(kw_dist, &dist);

  /*--- text handling ---*/
  azp.swOn(&do_lower, kw_do_lower);
  azp.swOn(&do_utf8dashes, kw_do_utf8dashes);
  azp.swOn(&do_nolr, kw_do_nolr);

  /*--- output naming ---*/
  azp.vStr(kw_batch_id, &s_batch_id);
  azp.vStr(kw_x_ext, &s_x_ext);
  azp.vStr(kw_y_ext, &s_y_ext);

  /*--- target values ---*/
  azp.swOn(&do_binarize, kw_do_binarize);
  const bool keep_real_y = !do_binarize;
  if (keep_real_y) {
    azp.vFloat(kw_scale_y, &scale_y);
  }
  azp.vFloat(kw_min_yval, &min_yval);

  /*--- validation ---*/
  AzXi::throw_if_empty(s_feat_fn, eyec, kw_feat_fn);
  AzXi::throw_if_empty(s_xtyp, eyec, kw_xtyp);
  AzXi::throw_if_empty(s_xdic_fn, eyec, kw_xdic_fn);
  AzXi::throw_if_empty(s_inp_fn, eyec, kw_inp_fn);
  AzXi::throw_if_empty(s_rnm, eyec, kw_rnm);
  AzXi::throw_if_nonpositive(pch_sz, eyec, kw_pch_sz);
  AzXi::throw_if_nonpositive(pch_step, eyec, kw_pch_step);
  AzXi::throw_if_negative(padding, eyec, kw_padding);
  AzXi::throw_if_nonpositive(f_pch_sz, eyec, kw_f_pch_sz);
  AzXi::throw_if_nonpositive(f_pch_step, eyec, kw_f_pch_step);
  AzXi::throw_if_negative(f_padding, eyec, kw_f_padding);
  AzXi::throw_if_nonpositive(dist, eyec, kw_dist);

  if (f_pch_sz <= dist) {
    return;
  }
  AzBytArr s_msg(kw_f_pch_sz);
  s_msg << " must be no greater than " << kw_dist << ".";
  AzX::throw_if(true, AzInputError, eyec, s_msg.c_str());
}
/*-------------------------------------------------------------------------*/
/*
 * Read and validate the parameters for generating regions
 * (AzPrepText2_gen_regions_unsup_Param).
 * When the kw_do_no_skip switch is on, min_x and min_y are set to 0
 * instead of being read.
 * Validation checks:
 *  - s_xtyp against {kw_bow, kw_seq}
 *  - s_x_ext against {".xsmat", ".xsmatvar"}
 *  - s_y_ext against {".ysmat", ".ysmatvar"}
 *  - the emptiness / sign constraints below
 * Failures throw via AzXi.
 */
void resetParam(AzParam &azp)
{
  const char *eyec = "AzPrepText2_gen_regions_unsup_Param::resetParam";

  /*--- files and types ---*/
  azp.vStr(kw_xtyp, &s_xtyp);
  azp.vStr(kw_xdic_fn, &s_xdic_fn);
  azp.vStr(kw_ydic_fn, &s_ydic_fn);
  azp.vStr(kw_inp_fn, &s_inp_fn);
  azp.vStr(kw_rnm, &s_rnm);
  azp.vStr(kw_txt_ext, &s_txt_ext);

  /*--- regions ---*/
  azp.vInt(kw_pch_sz, &pch_sz);
  azp.vInt(kw_pch_step, &pch_step);
  azp.vInt(kw_padding, &padding);
  azp.vInt(kw_dist, &dist);

  /*--- text handling ---*/
  azp.swOn(&do_lower, kw_do_lower);
  azp.swOn(&do_utf8dashes, kw_do_utf8dashes);
  azp.swOn(&do_nolr, kw_do_nolr);

  /*--- output naming ---*/
  azp.vStr(kw_batch_id, &s_batch_id);
  azp.vStr(kw_x_ext, &s_x_ext);
  azp.vStr(kw_y_ext, &s_y_ext);

  /*--- minimum counts, unless skipping is disabled ---*/
  azp.swOn(&do_no_skip, kw_do_no_skip, false);
  if (!do_no_skip) {
    azp.vInt(kw_min_x, &min_x);
    azp.vInt(kw_min_y, &min_y);
  }
  else {
    min_x = 0;
    min_y = 0;
  }

  /*--- validation ---*/
  AzStrPool sp_typ(10,10);
  sp_typ.put(kw_bow, kw_seq);
  AzXi::check_input(s_xtyp, &sp_typ, eyec, kw_xtyp);
  AzXi::throw_if_empty(s_xdic_fn, eyec, kw_xdic_fn);
  AzXi::throw_if_empty(s_ydic_fn, eyec, kw_ydic_fn);
  AzXi::throw_if_empty(s_inp_fn, eyec, kw_inp_fn);
  AzXi::throw_if_empty(s_rnm, eyec, kw_rnm);
  AzXi::throw_if_nonpositive(pch_sz, eyec, kw_pch_sz);
  AzXi::throw_if_nonpositive(pch_step, eyec, kw_pch_step);
  AzXi::throw_if_negative(padding, eyec, kw_padding);
  AzXi::throw_if_nonpositive(dist, eyec, kw_dist);

  AzStrPool sp_x_ext(10,10);
  sp_x_ext.put(".xsmat", ".xsmatvar");
  AzStrPool sp_y_ext(10,10);
  sp_y_ext.put(".ysmat", ".ysmatvar");
  AzXi::check_input(s_x_ext, &sp_x_ext, eyec, kw_x_ext);
  AzXi::check_input(s_y_ext, &sp_y_ext, eyec, kw_y_ext);
}
/*--------------------------------------------------------*/
/*
 * Read the SVRG parameters.  loss_type is changed only when kw_loss gives a
 * non-empty string, which is converted by lossType().  All other keywords
 * are read directly into their members.
 */
void AzsSvrg::resetParam(AzParam &azp)
{
  AzBytArr s_loss_name;
  azp.vStr(kw_loss, &s_loss_name);
  const bool loss_given = (s_loss_name.length() > 0);
  if (loss_given) {
    loss_type = lossType(s_loss_name.c_str());
  }

  /*--- iteration schedule ---*/
  azp.vInt(kw_svrg_interval, &svrg_interval);
  azp.vInt(kw_sgd_ite, &sgd_ite);
  azp.vInt(kw_test_interval, &test_interval);
  azp.vInt(kw_ite_num, &ite_num);
  azp.vInt(kw_rseed, &rseed);

  /*--- optimizer settings ---*/
  azp.vFloat(kw_eta, &eta);
  azp.vFloat(kw_momentum, &momentum);
  azp.vFloat(kw_lam, &lam);

  /*--- switches ---*/
  azp.swOn(&do_compact, kw_do_compact);
  azp.swOn(&do_show_loss, kw_do_show_loss);
  azp.swOn(&do_show_timing, kw_do_show_timing);
  azp.swOn(&with_replacement, kw_with_replacement);

  azp.vStr(kw_pred_fn, &s_pred_fn);
}