示例#1
0
  /* Flush the accumulated gradients into the weights (m_w) and the intercept (v_i) 
   * by calling adad_update() once for each, then reset the gradient counter. 
   *
   * Returns immediately, changing nothing, when p.dont_update() is true or when 
   * no gradient has been accumulated (grad_num <= 0). 
   *
   *   p   general parameters; dont_update(), do_reg_intercept and reg_L2const are read here. 
   *   pa  AdaD-specific parameters; only forwarded to adad_update(). 
   */
  void flushDelta(const AzpLmParam &p, const AzpLmAdaD_Param &pa) {
    if (p.dont_update() || grad_num <= 0) return; 
    check_ws("AzpLmAdaD::flushDelta"); 

    /* Weights: the last argument is always true here                    */
    /* (presumably it switches regularization on -- confirm in adad_update). */
    bool reg_weights = true; 
    adad_update(&m_w, &m_w_grad, &m_w_g2avg, &m_w_dlt, &m_w_init, 
                p, pa, reg_weights); 

    /* Intercept: the same flag follows p.do_reg_intercept. */
    bool reg_intercept = p.do_reg_intercept; 
    adad_update(&v_i, &v_i_grad, &v_i_g2avg, &v_i_dlt, &v_i_init, 
                p, pa, reg_intercept); 

    if (p.reg_L2const > 0) {
      do_l2const(p); 
    }
    grad_num = 0;  /* the accumulated gradients have been consumed */
  }
示例#2
0
/*------------------------------------------------------------*/  
void AzsSvrg::flushDelta() 
{
  if (momentum > 0) { /* use momentum */
    check_ws("flushDelta with momentum"); 
    m_w_delta.multiply(-eta); 
    m_w_delta.add(&m_w, -eta*lam); 
    m_w_delta.add(&m_w_delta_prev, momentum); 
    m_w.add(&m_w_delta); 
    m_w_delta_prev.set(&m_w_delta); 
  }
  else { /* don't use momentum */
    ws *= (1 - lam*eta);  
    m_w.add(&m_w_delta, -eta/ws);   
  }
  m_w_delta.zeroOut(); 
  
  if (ws < 1e-10) {
    flush_ws(); 
  }
}         
示例#3
0
/*------------------------------------------------------------*/ 
void AzpLmSgd::regularize(const AzpLmParam &p, double eta, double etab) 
{
  AzX::no_support((p.reg_L2init > 0), "AzpLmSgd::regularize", "reg_L2init in this configuration"); 

  if (p.reg_L2 == 0) {}
  else if (p.reg_L2 > 0) {
    ws *= (1 - p.reg_L2 * eta); 
    if (p.do_reg_intercept) {
      v_i.multiply(1 - p.reg_L2*etab); 
    }
  }
  else if (p.reg_L1L2 > 0) {
    check_ws("AzpLmSgd::regularize(L1L2)"); 
    AzPmatApp app; 
    AzPmat m; 
    app.l1l2deriv(&m_w, &m, (AzFloat)p.reg_L1L2_delta); 
    m_w.add(&m, -p.reg_L1L2*eta); 
    if (p.do_reg_intercept) {
      app.l1l2deriv(&v_i, &m, (AzFloat)p.reg_L1L2_delta); 
      v_i.add(&m, -p.reg_L1L2*etab); 
    }   
  }
}
示例#4
0
/*------------------------------------------------------------*/  
void AzpLmSgd::flushDelta(const AzpLmParam &p, const AzpLmSgd_Param &ps)
{
  if (p.dont_update()) return; 
  if (grad_num <= 0) return; 

  double etab = (ps.etab_coeff == 1) ? ps.eta : ps.eta*ps.etab_coeff; 
  if (ps.do_fast_flush && ps.momentum > 0) {  
    check_ws("flushDelta with momentum (fast_flush)"); 

    double mm = MAX(0,ps.momentum);  
    if (p.reg_L2 > 0 && p.reg_L2init <= 0) {
      m_w_dlt.add(mm, &m_w_grad, -ps.eta/(double)grad_num, &m_w, -ps.eta*p.reg_L2);      
    }
    else {     
      m_w_dlt.add(mm, &m_w_grad, -ps.eta/(double)grad_num); 
      if (!p.no_regadd()) {
        add_reg_grad(p, ps.eta, &m_w, &m_w_dlt, &m_w_init); 
      }   
    } 
    m_w.add(&m_w_dlt); 

    v_i_dlt.add(mm, &v_i_grad, -etab/(double)grad_num);  
    if (p.do_reg_intercept && !p.no_regadd()) {
      add_reg_grad(p, etab, &v_i, &v_i_dlt, &v_i_init);  /* regularization */      
    }    
    v_i.add(&v_i_dlt); 
    do_gradpart = false; 
  }
  else if (ps.momentum > 0 || p.reg_L2init > 0) { /* use momentum; slower; keeping this for compatibility */
    check_ws("flushDelta with momentum"); 
    double mm = MAX(0,ps.momentum); 
    m_w_grad.multiply(-ps.eta/(double)grad_num); 
    add_reg_grad(p, ps.eta, &m_w, &m_w_grad, &m_w_init);     /* regularization */
    m_w_grad.add(&m_w_dlt, mm); 
    m_w.add(&m_w_grad); 
    m_w_dlt.set_chk(&m_w_grad);  
    
    v_i_grad.multiply(-etab/(double)grad_num); 
    if (p.do_reg_intercept) {
      add_reg_grad(p, etab, &v_i, &v_i_grad, &v_i_init);  /* regularization */      
    }
    v_i_grad.add(&v_i_dlt, mm); 
    v_i.add(&v_i_grad); 
    v_i_dlt.set_chk(&v_i_grad); 
    do_gradpart = false; 
  }
  else { /* don't use momentum */
    regularize(p, ps.eta, etab); 
    if (doing_partial() && do_gradpart) {
      m_w.add_s2d(&m_w_grad, ia_p2w.point(), ia_p2w.size(), -ps.eta/(double)grad_num/ws); 
      v_i.add_s2d(&v_i_grad, ia_p2w.point(), ia_p2w.size(), -etab/(double)grad_num);      
    }
    else {
      m_w.add(&m_w_grad, -ps.eta/(double)grad_num/ws); 
      v_i.add(&v_i_grad, -etab/(double)grad_num); 
    }
    do_gradpart = doing_partial(); 
  }
  
  if (p.reg_L2const > 0) do_l2const(p); 
  grad_num = 0; 
  if (ws < 1e-4) flush_ws(); 
}