Exemplo n.º 1
0
/*!
 * \brief Lazily allocate the per-weight RMSProp state buffers for `index`.
 *
 * Three buffers (stored in n_, g_ and delta_) are created with the same
 * shape and context as `weight` and filled with zeros. The raw pointers are
 * stored in the member maps; presumably they are released by the optimizer
 * elsewhere (e.g. its destructor) -- confirm against the class definition.
 *
 * \param index  key identifying the weight whose state is being created
 * \param weight weight array whose shape/context the buffers mirror
 */
inline void RMSPropOptimizer::CreateState_(int index, NDArray weight) {
  // Store a fresh array into the given map slot first, then zero it.
  auto zero_fill = [&weight](NDArray *&slot) {
    slot = new NDArray(weight.GetShape(), weight.GetContext());
    *slot = 0;
  };
  zero_fill(n_[index]);
  zero_fill(g_[index]);
  zero_fill(delta_[index]);
}
Exemplo n.º 2
0
inline void SGDOptimizer::CreateState_(int index, NDArray weight) {
  if (params_.count("momentum") == 0) {
    states_[index] = nullptr;
  } else {
    states_[index] = new NDArray(weight.GetShape(), weight.GetContext());
    *states_[index] = 0;
  }
}
Exemplo n.º 3
0
inline void AdaDeltaOptimizer::CreateState_(int index, NDArray weight) {
  acc_g_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *acc_g_[index] = 0;
  acc_delta_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *acc_delta_[index] = 0;
}
Exemplo n.º 4
0
inline void AdaGradOptimizer::CreateState_(int index, NDArray weight) {
  history_[index] = new NDArray(weight.GetShape(), weight.GetContext());
  *history_[index] = 0;
}
Exemplo n.º 5
0
/*!
 * \brief Allocate the zero-initialized Adam state buffers (mean_ and var_)
 *        for `index`, matching the shape and context of `weight`.
 *
 * \param index  key identifying the weight whose state is being created
 * \param weight weight array whose shape/context the buffers mirror
 */
inline void AdamOptimizer::CreateState_(int index, NDArray weight) {
  // Each slot is stored in its map before being zero-filled.
  auto init_slot = [&weight](NDArray *&target) {
    target = new NDArray(weight.GetShape(), weight.GetContext());
    *target = 0;
  };
  init_slot(mean_[index]);
  init_slot(var_[index]);
}