inline void RMSPropOptimizer::CreateState_(int index, NDArray weight) { n_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *n_[index] = 0; g_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *g_[index] = 0; delta_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *delta_[index] = 0; }
inline void SGDOptimizer::CreateState_(int index, NDArray weight) { if (params_.count("momentum") == 0) { states_[index] = nullptr; } else { states_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *states_[index] = 0; } }
inline void AdaDeltaOptimizer::CreateState_(int index, NDArray weight) { acc_g_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *acc_g_[index] = 0; acc_delta_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *acc_delta_[index] = 0; }
inline void AdaGradOptimizer::CreateState_(int index, NDArray weight) { history_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *history_[index] = 0; }
inline void AdamOptimizer::CreateState_(int index, NDArray weight) { mean_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *mean_[index] = 0; var_[index] = new NDArray(weight.GetShape(), weight.GetContext()); *var_[index] = 0; }