// Applies one SGD step to `weight` using `grad`. Dispatches to the plain
// SGD op when no momentum state exists for this parameter index, and to the
// momentum variant otherwise.
inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (states_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[3];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();

  // The op writes its result back into the weight array in place.
  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  if (states_[index] == nullptr) {
    // No momentum buffer: plain SGD takes only weight and grad.
    MXImperativeInvoke(update_handle_, 2, inputs, &num_outputs, &outputs,
                       keys.size(), keys.data(), values.data());
  } else {
    // Momentum SGD additionally takes the per-parameter momentum buffer.
    inputs[2] = states_[index]->GetHandle();
    MXImperativeInvoke(mom_update_handle_, 3, inputs, &num_outputs, &outputs,
                       keys.size(), keys.data(), values.data());
  }
}
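// --- Usage sketch (illustrative only, not part of the original source) ---
// How a training loop might drive the SGD update above. ExampleSgdStep is a
// hypothetical helper; OptimizerRegistry::Find, the chained SetParam calls,
// and the Executor's arg_arrays/grad_arrays members follow the mxnet-cpp
// API, but the hyper-parameter values here are placeholders.
inline void ExampleSgdStep(Executor *exec) {
  Optimizer *opt = OptimizerRegistry::Find("sgd");
  opt->SetParam("lr", 0.1f)->SetParam("wd", 1e-4f)->SetParam("momentum", 0.9f);
  for (size_t i = 0; i < exec->arg_arrays.size(); ++i) {
    // One step per parameter array; index i selects the per-parameter state
    // (here, the momentum buffer).
    opt->Update(static_cast<int>(i), exec->arg_arrays[i], exec->grad_arrays[i]);
  }
}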
// Applies one RMSProp step to `weight` via the Graves-variant op bound to
// alex_update_handle_, which takes three state buffers per parameter: the
// squared-gradient accumulator n, the gradient accumulator g, and the
// update accumulator delta.
inline void RMSPropOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (n_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[5];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();
  inputs[2] = n_[index]->GetHandle();
  inputs[3] = g_[index]->GetHandle();
  inputs[4] = delta_[index]->GetHandle();

  // The op writes its result back into the weight array in place.
  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  MXImperativeInvoke(alex_update_handle_, 5, inputs, &num_outputs, &outputs,
                     keys.size(), keys.data(), values.data());
}
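// --- State-allocation sketch (illustrative only) ---
// The update above assumes n_[index], g_[index], and delta_[index] were
// zero-initialized by CreateState_ with the weight's shape and context.
// MakeZeroStateLike is a hypothetical helper showing that invariant; the
// shape/context NDArray constructor and scalar assignment follow mxnet-cpp.
inline NDArray *MakeZeroStateLike(NDArray weight) {
  NDArray *state = new NDArray(weight.GetShape(), weight.GetContext());
  *state = 0;  // start each accumulator from zero
  return state;
}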
// Base-class update path: lazily creates a C-API optimizer handle from the
// current params_ on the first call, then delegates each step to
// MXOptimizerUpdate. The learning rate and weight decay are passed on every
// call; all other parameters are frozen into the handle at creation time.
void Optimizer::Update(int index, NDArray weight, NDArray grad) {
  if (!init_) {
    std::vector<const char *> param_keys;
    std::vector<const char *> param_values;
    for (const auto &k_v : params_) {
      param_keys.push_back(k_v.first.c_str());
      param_values.push_back(k_v.second.c_str());
    }
    MXOptimizerCreateOptimizer(creator_, params_.size(), param_keys.data(),
                               param_values.data(), &handle_);
    init_ = true;
  }
  MXOptimizerUpdate(handle_, index, weight.GetHandle(), grad.GetHandle(),
                    learning_rate_, weight_decay_);
}
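// --- Ordering caveat sketch (illustrative only) ---
// Because the handle above is built from a snapshot of params_, optimizer
// hyper-parameters must be set before the first Update. ExampleLegacyUsage
// is a hypothetical illustration of that ordering; the values are
// placeholders.
inline void ExampleLegacyUsage(Optimizer *opt, NDArray weight, NDArray grad) {
  opt->SetParam("momentum", 0.9f);  // effective: precedes handle creation
  opt->Update(0, weight, grad);     // first call creates the C-API handle
  opt->SetParam("momentum", 0.5f);  // ineffective: the handle already exists
  opt->Update(0, weight, grad);
}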
// Applies one Adam step to `weight`. The bias-corrected learning rate is
// written back into params_ before the key/value snapshot is taken, so the
// update op actually receives it.
inline void AdamOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (mean_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);

  // Fold the Adam bias correction into the learning rate:
  //   lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
  float lr = std::stof(params_["lr"]);
  float b1 = std::stof(params_["beta1"]);
  float b2 = std::stof(params_["beta2"]);
  float t = count_[index];
  float coef1 = 1.0f - std::pow(b1, t);
  float coef2 = 1.0f - std::pow(b2, t);
  lr *= std::sqrt(coef2) / coef1;
  params_["lr"] = std::to_string(lr);

  // Snapshot keys/values only after every params_ entry is final: the
  // vectors hold pointers into the strings stored in params_.
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[4];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();
  inputs[2] = mean_[index]->GetHandle();
  inputs[3] = var_[index]->GetHandle();

  // The op writes its result back into the weight array in place.
  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  MXImperativeInvoke(update_handle_, 4, inputs, &num_outputs, &outputs,
                     keys.size(), keys.data(), values.data());
}
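// --- Usage sketch (illustrative only) ---
// Driving the Adam update above from a training loop. ExampleAdamStep is a
// hypothetical helper; OptimizerRegistry::Find("adam") and the chained
// SetParam calls follow the mxnet-cpp API, and the hyper-parameter values
// are common defaults rather than values mandated by this file.
inline void ExampleAdamStep(std::vector<NDArray> &weights,
                            std::vector<NDArray> &grads) {
  Optimizer *opt = OptimizerRegistry::Find("adam");
  opt->SetParam("lr", 1e-3f)
     ->SetParam("beta1", 0.9f)
     ->SetParam("beta2", 0.999f)
     ->SetParam("epsilon", 1e-8f);
  for (size_t i = 0; i < weights.size(); ++i) {
    opt->Update(static_cast<int>(i), weights[i], grads[i]);
  }
}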