inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) { if (states_.count(index) == 0) { CreateState_(index, weight); } params_["lr"] = std::to_string(GetLR_(index)); params_["wd"] = std::to_string(GetWD_(index)); UpdateCount_(index); auto keys = GetParamKeys_(); auto values = GetParamValues_(); CHECK_EQ(keys.size(), values.size()); NDArrayHandle inputs[3]; inputs[0] = weight.GetHandle(); inputs[1] = grad.GetHandle(); int num_outputs = 1; NDArrayHandle output = weight.GetHandle(); NDArrayHandle *outputs = &output; if (states_[index] == nullptr) { MXImperativeInvoke(update_handle_, 2, inputs, &num_outputs, &outputs, keys.size(), keys.data(), values.data()); } else { inputs[2] = states_[index]->GetHandle(); MXImperativeInvoke(mom_update_handle_, 3, inputs, &num_outputs, &outputs, keys.size(), keys.data(), values.data()); } }
inline void RMSPropOptimizer::Update(int index, NDArray weight, NDArray grad) { if (n_.count(index) == 0) { CreateState_(index, weight); } params_["lr"] = std::to_string(GetLR_(index)); params_["wd"] = std::to_string(GetWD_(index)); UpdateCount_(index); auto keys = GetParamKeys_(); auto values = GetParamValues_(); CHECK_EQ(keys.size(), values.size()); NDArrayHandle inputs[5]; inputs[0] = weight.GetHandle(); inputs[1] = grad.GetHandle(); inputs[2] = n_[index]->GetHandle(); inputs[3] = g_[index]->GetHandle(); inputs[4] = delta_[index]->GetHandle(); int num_outputs = 1; NDArrayHandle output = weight.GetHandle(); NDArrayHandle *outputs = &output; MXImperativeInvoke(alex_update_handle_, 5, inputs, &num_outputs, &outputs, keys.size(), keys.data(), values.data()); }
inline void Operator::Invoke(std::vector<NDArray> &outputs) { if (input_keys_.size() > 0) { CHECK_EQ(input_keys_.size(), input_ndarrays_.size()); } std::vector<const char *> input_keys; std::vector<const char *> param_keys; std::vector<const char *> param_values; for (auto &data : params_) { param_keys.push_back(data.first.c_str()); param_values.push_back(data.second.c_str()); } int num_inputs = input_ndarrays_.size(); int num_outputs = outputs.size(); std::vector<NDArrayHandle> output_handles; std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_handles), [](NDArray& a) { return a.GetHandle(); }); NDArrayHandle *outputs_receiver = nullptr; if (num_outputs > 0) { outputs_receiver = output_handles.data(); } MXImperativeInvoke(handle_, num_inputs, input_ndarrays_.data(), &num_outputs, &outputs_receiver, param_keys.size(), param_keys.data(), param_values.data()); if (outputs.size() > 0) return; std::transform(outputs_receiver, outputs_receiver+num_outputs, std::back_inserter(outputs), [](const NDArrayHandle& handle) { return NDArray(handle); }); }
inline void AdamOptimizer::Update(int index, NDArray weight, NDArray grad) { if (mean_.count(index) == 0) { CreateState_(index, weight); } params_["lr"] = std::to_string(GetLR_(index)); params_["wd"] = std::to_string(GetWD_(index)); UpdateCount_(index); auto keys = GetParamKeys_(); auto values = GetParamValues_(); CHECK_EQ(keys.size(), values.size()); float lr = std::stof(params_["lr"]); float wd = std::stof(params_["wd"]); float b1 = std::stof(params_["beta1"]); float b2 = std::stof(params_["beta2"]); float t = count_[index]; float coef1 = 1.0f - std::pow(b1, t); float coef2 = 1.0f - std::pow(b2, t); lr *= std::sqrt(coef2) / coef1; NDArrayHandle inputs[4]; inputs[0] = weight.GetHandle(); inputs[1] = grad.GetHandle(); int num_outputs = 1; NDArrayHandle output = weight.GetHandle(); NDArrayHandle *outputs = &output; inputs[2] = mean_[index]->GetHandle(); inputs[3] = var_[index]->GetHandle(); MXImperativeInvoke(update_handle_, 4, inputs, &num_outputs, &outputs, keys.size(), keys.data(), values.data()); }