Example #1
inline void SGDOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (states_.count(index) == 0) {
    CreateState_(index, weight);
  }

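  // Refresh the per-index learning rate and weight decay; both are passed
  // to the imperative op as string key/value parameters.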
  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[3];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();

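  // The update happens in place: the single output aliases the weight array.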
  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

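  // A null state means momentum is disabled, so the two-input SGD update op
  // is invoked; otherwise the momentum variant also consumes the state array.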
  if (states_[index] == nullptr) {
    MXImperativeInvoke(update_handle_, 2, inputs,
        &num_outputs, &outputs,
        keys.size(), keys.data(), values.data());
  } else {
    inputs[2] = states_[index]->GetHandle();
    MXImperativeInvoke(mom_update_handle_, 3, inputs,
        &num_outputs, &outputs,
        keys.size(), keys.data(), values.data());
  }
}
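In a typical cpp-package training loop, Update is called once per parameter array per iteration, with the index serving as the key for per-parameter state. A minimal sketch, assuming the usual OptimizerRegistry::Find factory and SetParam chaining; the executor variable exec and its arg_arrays/grad_arrays members are illustrative:

// Illustrative excerpt; `exec` and its array members are assumptions.
Optimizer* opt = OptimizerRegistry::Find("sgd");
opt->SetParam("lr", 0.01)->SetParam("momentum", 0.9);
for (int i = 0; i < static_cast<int>(exec->arg_arrays.size()); ++i) {
  // The index doubles as the key for per-parameter state such as momentum.
  opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
}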
Example #2
inline void RMSPropOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (n_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

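  // The Alex (Graves-style) RMSProp variant consumes the three state arrays
  // n, g, and delta in addition to the weight and gradient.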
  NDArrayHandle inputs[5];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();
  inputs[2] = n_[index]->GetHandle();
  inputs[3] = g_[index]->GetHandle();
  inputs[4] = delta_[index]->GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

  MXImperativeInvoke(alex_update_handle_, 5, inputs,
      &num_outputs, &outputs,
      keys.size(), keys.data(), values.data());
}
Example #3
inline void Operator::Invoke(std::vector<NDArray> &outputs) {
  if (input_keys_.size() > 0) {
    CHECK_EQ(input_keys_.size(), input_ndarrays_.size());
  }

  std::vector<const char *> param_keys;
  std::vector<const char *> param_values;

  for (auto &data : params_) {
    param_keys.push_back(data.first.c_str());
    param_values.push_back(data.second.c_str());
  }

  int num_inputs = input_ndarrays_.size();
  int num_outputs = outputs.size();
  std::vector<NDArrayHandle> output_handles;
  std::transform(outputs.begin(), outputs.end(),
      std::back_inserter(output_handles), [](NDArray& a) {
        return a.GetHandle();
      });

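  // A null receiver asks MXImperativeInvoke to allocate the outputs itself.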
  NDArrayHandle *outputs_receiver = nullptr;
  if (num_outputs > 0) {
    outputs_receiver = output_handles.data();
  }

  MXImperativeInvoke(handle_, num_inputs, input_ndarrays_.data(),
      &num_outputs, &outputs_receiver,
      param_keys.size(), param_keys.data(), param_values.data());

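  // If the caller supplied output arrays, the op has already written into
  // them in place and there is nothing left to wrap.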
  if (outputs.size() > 0)
    return;

  std::transform(outputs_receiver, outputs_receiver+num_outputs,
      std::back_inserter(outputs), [](const NDArrayHandle& handle) {
        return NDArray(handle);
      });
}
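MXImperativeInvoke supports two output modes, and this wrapper handles both: when the caller passes pre-allocated NDArrays, the op writes into them and the early return fires; when outputs is empty, a null receiver is passed, the engine allocates the results, and the final std::transform wraps the returned handles. A minimal caller-side sketch, assuming the cpp-package Operator fluent interface; the operator name and the input arrays lhs/rhs are illustrative:

// Let the engine allocate the result; `lhs` and `rhs` are assumed NDArrays.
std::vector<NDArray> outputs;
Operator("elemwise_add")
    .SetInput("lhs", lhs)
    .SetInput("rhs", rhs)
    .Invoke(outputs);
// outputs now holds one engine-allocated NDArray containing lhs + rhs.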
Example #4
inline void AdamOptimizer::Update(int index, NDArray weight, NDArray grad) {
  if (mean_.count(index) == 0) {
    CreateState_(index, weight);
  }

  params_["lr"] = std::to_string(GetLR_(index));
  params_["wd"] = std::to_string(GetWD_(index));
  UpdateCount_(index);
  // Bias-correct the learning rate, lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t),
  // and write it back so the captured parameter strings carry the corrected value.
  float lr = std::stof(params_["lr"]);
  float b1 = std::stof(params_["beta1"]);
  float b2 = std::stof(params_["beta2"]);
  float t = count_[index];
  float coef1 = 1.0f - std::pow(b1, t);
  float coef2 = 1.0f - std::pow(b2, t);
  params_["lr"] = std::to_string(lr * std::sqrt(coef2) / coef1);

  auto keys = GetParamKeys_();
  auto values = GetParamValues_();
  CHECK_EQ(keys.size(), values.size());

  NDArrayHandle inputs[4];
  inputs[0] = weight.GetHandle();
  inputs[1] = grad.GetHandle();

  int num_outputs = 1;
  NDArrayHandle output = weight.GetHandle();
  NDArrayHandle *outputs = &output;

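  // Running estimates of the first and second gradient moments.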
  inputs[2] = mean_[index]->GetHandle();
  inputs[3] = var_[index]->GetHandle();

  MXImperativeInvoke(update_handle_, 4, inputs,
    &num_outputs, &outputs,
    keys.size(), keys.data(), values.data());
}