void flushDelta(const AzpLmParam &p, const AzpLmAdaD_Param &pa) { if (p.dont_update()) return; if (grad_num <= 0) return; check_ws("AzpLmAdaD::flushDelta"); bool do_reg = true; adad_update(&m_w, &m_w_grad, &m_w_g2avg, &m_w_dlt, &m_w_init, p, pa, do_reg); do_reg = p.do_reg_intercept; adad_update(&v_i, &v_i_grad, &v_i_g2avg, &v_i_dlt, &v_i_init, p, pa, do_reg); if (p.reg_L2const > 0) do_l2const(p); grad_num = 0; }
/*------------------------------------------------------------*/ void AzsSvrg::flushDelta() { if (momentum > 0) { /* use momentum */ check_ws("flushDelta with momentum"); m_w_delta.multiply(-eta); m_w_delta.add(&m_w, -eta*lam); m_w_delta.add(&m_w_delta_prev, momentum); m_w.add(&m_w_delta); m_w_delta_prev.set(&m_w_delta); } else { /* don't use momentum */ ws *= (1 - lam*eta); m_w.add(&m_w_delta, -eta/ws); } m_w_delta.zeroOut(); if (ws < 1e-10) { flush_ws(); } }
/*------------------------------------------------------------*/ void AzpLmSgd::regularize(const AzpLmParam &p, double eta, double etab) { AzX::no_support((p.reg_L2init > 0), "AzpLmSgd::regularize", "reg_L2init in this configuration"); if (p.reg_L2 == 0) {} else if (p.reg_L2 > 0) { ws *= (1 - p.reg_L2 * eta); if (p.do_reg_intercept) { v_i.multiply(1 - p.reg_L2*etab); } } else if (p.reg_L1L2 > 0) { check_ws("AzpLmSgd::regularize(L1L2)"); AzPmatApp app; AzPmat m; app.l1l2deriv(&m_w, &m, (AzFloat)p.reg_L1L2_delta); m_w.add(&m, -p.reg_L1L2*eta); if (p.do_reg_intercept) { app.l1l2deriv(&v_i, &m, (AzFloat)p.reg_L1L2_delta); v_i.add(&m, -p.reg_L1L2*etab); } } }
/*------------------------------------------------------------*/ void AzpLmSgd::flushDelta(const AzpLmParam &p, const AzpLmSgd_Param &ps) { if (p.dont_update()) return; if (grad_num <= 0) return; double etab = (ps.etab_coeff == 1) ? ps.eta : ps.eta*ps.etab_coeff; if (ps.do_fast_flush && ps.momentum > 0) { check_ws("flushDelta with momentum (fast_flush)"); double mm = MAX(0,ps.momentum); if (p.reg_L2 > 0 && p.reg_L2init <= 0) { m_w_dlt.add(mm, &m_w_grad, -ps.eta/(double)grad_num, &m_w, -ps.eta*p.reg_L2); } else { m_w_dlt.add(mm, &m_w_grad, -ps.eta/(double)grad_num); if (!p.no_regadd()) { add_reg_grad(p, ps.eta, &m_w, &m_w_dlt, &m_w_init); } } m_w.add(&m_w_dlt); v_i_dlt.add(mm, &v_i_grad, -etab/(double)grad_num); if (p.do_reg_intercept && !p.no_regadd()) { add_reg_grad(p, etab, &v_i, &v_i_dlt, &v_i_init); /* regularization */ } v_i.add(&v_i_dlt); do_gradpart = false; } else if (ps.momentum > 0 || p.reg_L2init > 0) { /* use momentum; slower; keeping this for compatibility */ check_ws("flushDelta with momentum"); double mm = MAX(0,ps.momentum); m_w_grad.multiply(-ps.eta/(double)grad_num); add_reg_grad(p, ps.eta, &m_w, &m_w_grad, &m_w_init); /* regularization */ m_w_grad.add(&m_w_dlt, mm); m_w.add(&m_w_grad); m_w_dlt.set_chk(&m_w_grad); v_i_grad.multiply(-etab/(double)grad_num); if (p.do_reg_intercept) { add_reg_grad(p, etab, &v_i, &v_i_grad, &v_i_init); /* regularization */ } v_i_grad.add(&v_i_dlt, mm); v_i.add(&v_i_grad); v_i_dlt.set_chk(&v_i_grad); do_gradpart = false; } else { /* don't use momentum */ regularize(p, ps.eta, etab); if (doing_partial() && do_gradpart) { m_w.add_s2d(&m_w_grad, ia_p2w.point(), ia_p2w.size(), -ps.eta/(double)grad_num/ws); v_i.add_s2d(&v_i_grad, ia_p2w.point(), ia_p2w.size(), -etab/(double)grad_num); } else { m_w.add(&m_w_grad, -ps.eta/(double)grad_num/ws); v_i.add(&v_i_grad, -etab/(double)grad_num); } do_gradpart = doing_partial(); } if (p.reg_L2const > 0) do_l2const(p); grad_num = 0; if (ws < 1e-4) flush_ws(); }