/*------------------------------------------------------------*/  
void AzsSvrg::flushDelta() 
{
  if (momentum > 0) { /* use momentum */
    check_ws("flushDelta with momentum"); 
    m_w_delta.multiply(-eta); 
    m_w_delta.add(&m_w, -eta*lam); 
    m_w_delta.add(&m_w_delta_prev, momentum); 
    m_w.add(&m_w_delta); 
    m_w_delta_prev.set(&m_w_delta); 
  }
  else { /* don't use momentum */
    ws *= (1 - lam*eta);  
    m_w.add(&m_w_delta, -eta/ws);   
  }
  m_w_delta.zeroOut(); 
  
  if (ws < 1e-10) {
    flush_ws(); 
  }
}         
/* Example #2 */
/*------------------------------------------------------------*/  
void AzpLmSgd::flushDelta(const AzpLmParam &p, const AzpLmSgd_Param &ps)
{
  /*
   * Apply the gradients accumulated since the last flush and reset grad_num
   * to 0.  Weight gradients are in m_w_grad (applied to m_w); intercept
   * gradients are in v_i_grad (applied to v_i).  Both are divided by grad_num,
   * presumably because they are sums over grad_num contributions.
   *
   * Does nothing if p.dont_update() is true or grad_num <= 0.
   *
   * One of three update paths is taken:
   *   1. ps.do_fast_flush and ps.momentum > 0: the momentum deltas
   *      m_w_dlt / v_i_dlt are updated in place and added to the parameters.
   *   2. otherwise, ps.momentum > 0 or p.reg_L2init > 0: the older, slower
   *      momentum path kept for compatibility; the scaled gradient buffers
   *      become the new deltas.
   *   3. otherwise: plain SGD via regularize() plus the scaled gradient; the
   *      weight step (not the intercept step) is divided by ws.
   * Afterwards: do_l2const(p) if p.reg_L2const > 0, and flush_ws() if
   * ws < 1e-4.
   *
   * NOTE(review): "X.add(a, &g, c)" is read as X = a*X + c*g (with an extra
   * "+ c2*w" for the 5-argument form) and "X.add(&g, c)" as X += c*g --
   * confirm against the vector class.
   */
  if (p.dont_update() || grad_num <= 0) return;

  /* Learning rate for the intercept: ps.eta, scaled by ps.etab_coeff unless it is 1. */
  double eta_b = ps.eta;
  if (ps.etab_coeff != 1) eta_b = ps.eta*ps.etab_coeff;

  /* Step coefficients applied to the accumulated gradients. */
  const double w_coef = -ps.eta/(double)grad_num;
  const double b_coef = -eta_b/(double)grad_num;

  double mom = MAX(0,ps.momentum);  /* momentum clamped at 0; read only on paths 1 and 2 */

  const bool has_momentum = (ps.momentum > 0);
  const bool fast_momentum = ps.do_fast_flush && has_momentum;
  const bool legacy_momentum = has_momentum || p.reg_L2init > 0;

  if (fast_momentum) {
    check_ws("flushDelta with momentum (fast_flush)");

    /* Weights.  With plain L2 (reg_L2 > 0, no reg_L2init) the decay term
       -eta*reg_L2*w goes into the same call; otherwise add_reg_grad()
       supplies the regularization term unless p.no_regadd(). */
    if (p.reg_L2 > 0 && p.reg_L2init <= 0) {
      m_w_dlt.add(mom, &m_w_grad, w_coef, &m_w, -ps.eta*p.reg_L2);
    }
    else {
      m_w_dlt.add(mom, &m_w_grad, w_coef);
      if (!p.no_regadd()) add_reg_grad(p, ps.eta, &m_w, &m_w_dlt, &m_w_init);
    }
    m_w.add(&m_w_dlt);

    /* Intercept: regularized only when requested and p.no_regadd() is false. */
    v_i_dlt.add(mom, &v_i_grad, b_coef);
    if (p.do_reg_intercept && !p.no_regadd()) add_reg_grad(p, eta_b, &v_i, &v_i_dlt, &v_i_init);
    v_i.add(&v_i_dlt);

    do_gradpart = false;
  }
  else if (legacy_momentum) {
    /* Slower path kept for compatibility; also taken whenever reg_L2init > 0,
       in which case mom may be 0.  Each gradient buffer is turned into the
       new delta in place:
         g <- coef*g + reg term + mom*dlt;   param <- param + g;   dlt <- g  */
    check_ws("flushDelta with momentum");

    m_w_grad.multiply(w_coef);
    add_reg_grad(p, ps.eta, &m_w, &m_w_grad, &m_w_init);
    m_w_grad.add(&m_w_dlt, mom);
    m_w.add(&m_w_grad);
    m_w_dlt.set_chk(&m_w_grad);

    v_i_grad.multiply(b_coef);
    if (p.do_reg_intercept) add_reg_grad(p, eta_b, &v_i, &v_i_grad, &v_i_init);
    v_i_grad.add(&v_i_dlt, mom);
    v_i.add(&v_i_grad);
    v_i_dlt.set_chk(&v_i_grad);

    do_gradpart = false;
  }
  else {
    /* Plain SGD.  regularize() runs first, and ws is read only after it.
       When doing_partial() and do_gradpart both hold, the update goes through
       add_s2d() with the ia_p2w index array (presumably touching only those
       entries); otherwise the full vectors are updated. */
    regularize(p, ps.eta, eta_b);
    if (doing_partial() && do_gradpart) {
      m_w.add_s2d(&m_w_grad, ia_p2w.point(), ia_p2w.size(), w_coef/ws);
      v_i.add_s2d(&v_i_grad, ia_p2w.point(), ia_p2w.size(), b_coef);
    }
    else {
      m_w.add(&m_w_grad, w_coef/ws);
      v_i.add(&v_i_grad, b_coef);
    }
    do_gradpart = doing_partial();
  }

  if (p.reg_L2const > 0) do_l2const(p);
  grad_num = 0;  /* gradients consumed */
  if (ws < 1e-4) flush_ws();
}