Пример #1
0
void NHERD::train(const sfv_t& sfv, const string& label){
  string incorrect_label;
  float variance = 0.f;
  float margin = - calc_margin_and_variance(sfv, label, incorrect_label, variance);
  if (margin >= 1.f) {
    return;
  }
  update(sfv, margin, variance, label, incorrect_label);
}
Пример #2
0
// Performs one online NHERD training step on the sample `sfv` whose true
// label is `label`.
//
// The negated return value of calc_margin_and_variance() is the margin;
// `rival_label` and `var` are passed to it as out-arguments. When the margin
// is at least 1 no weight update happens, but the label is still registered
// with the storage. Otherwise update() is called with the margin, the
// variance and both labels.
void normal_herd::train(const common::sfv_t& sfv, const string& label) {
  string rival_label;
  float var = 0.f;
  const float raw_margin =
      calc_margin_and_variance(sfv, label, rival_label, var);
  float margin = -raw_margin;

  if (margin >= 1.f) {
    // No weight update, but the label is still registered.
    get_storage()->register_label(label);
    return;
  }

  update(sfv, margin, var, label, rival_label);
}
Пример #3
0
void AROW::train(const sfv_t& sfv, const string& label){
  string incorrect_label;
  float variance = 0.f;
  float margin = - calc_margin_and_variance(sfv, label, incorrect_label, variance);
   if (margin >= 1.f) {
    return;
  }
  float beta = 1.f / (variance + C_);
  float alpha = (1.f - margin) * beta; // max(0, 1-margin) = 1-margin 
  update(sfv, alpha, beta, label, incorrect_label);
}
Пример #4
0
// Performs one online Confidence-Weighted (CW) training step on the sample
// `sfv` whose true label is `label`.
//
// With m the margin (negated return of calc_margin_and_variance()), v the
// variance it writes through its out-argument, and C = config.C, the step is
// the positive root
//   gamma = (-b + sqrt(b^2 - 8C(m - C v))) / (4 C v),   b = 1 + 2 C m.
// No update is made when v is not positive (no finite step exists) or when
// gamma is not strictly positive (including NaN).
void CW::train(const sfv_t& sfv, const string& label){
  const float C = config.C;
  string incorrect_label;
  float variance = 0.f;
  float margin = - calc_margin_and_variance(sfv, label, incorrect_label, variance);

  // gamma is divided by 4*C*variance below; without positive variance the
  // step would be inf/NaN, so skip instead of corrupting the weights.
  if (!(variance > 0.f)) {
    return;
  }

  float b = 1.f + 2 * C * margin;
  // b^2 - 8C(m - Cv) == (1 - 2Cm)^2 + 8C^2 v. This form is a sum of
  // non-negative terms, so rounding cannot push it below zero and sqrt
  // cannot return NaN, whereas the expanded form can cancel to a tiny
  // negative value.
  float d = 1.f - 2 * C * margin;
  float gamma = - b + sqrt(d * d + 8 * C * C * variance);

  // `!(gamma > 0)` also rejects NaN, which `gamma <= 0` would let through.
  if (!(gamma > 0.f)){
    return;
  }
  gamma /= 4 * C * variance;
  update(sfv, gamma, label, incorrect_label);
}
Пример #5
0
// Performs one online AROW training step on the sample `sfv` whose true
// label is `label`.
//
// The margin is the negated return value of calc_margin_and_variance(),
// which also receives `rival_label` and `var` as out-arguments. When the
// margin is at least 1 the weights are left alone and the label is only
// registered with the storage. Otherwise:
//   beta  = 1 / (variance + 1 / config_.C)
//   alpha = (1 - margin) * beta
// and update() applies the step.
void arow::train(const common::sfv_t& sfv, const string& label) {
  string rival_label;
  float var = 0.f;
  const float raw_margin =
      calc_margin_and_variance(sfv, label, rival_label, var);
  const float margin = -raw_margin;

  if (margin >= 1.f) {
    get_storage()->register_label(label);
    return;
  }

  // margin < 1 here, so the hinge loss max(0, 1 - margin) is just 1 - margin.
  const float hinge = 1.f - margin;
  // config_.C enters through its reciprocal.
  float beta = 1.f / (var + 1.f / config_.C);
  float alpha = hinge * beta;
  update(sfv, alpha, beta, label, rival_label);
}
// Performs one online Confidence-Weighted (CW) training step on the sample
// `sfv` whose true label is `label`.
//
// check_touchable(label) is called first. With m the margin (negated return
// of calc_margin_and_variance()), v the variance it writes through its
// out-argument, and C = config_.regularization_weight, the step is
//   gamma = (-b + sqrt(b^2 - 8C(m - C v))) / (4 C v),   b = 1 + 2 C m.
// When v is not positive, or gamma is not strictly positive (including NaN),
// no weight update happens and the label is only registered with storage_.
void confidence_weighted::train(const common::sfv_t& sfv, const string& label) {
  check_touchable(label);

  const float C = config_.regularization_weight;
  string incorrect_label;
  float variance = 0.f;
  float margin = -calc_margin_and_variance(sfv, label, incorrect_label,
                                           variance);

  // gamma is divided by 4*C*variance below; without positive variance the
  // step would be inf/NaN, so skip the update instead.
  if (!(variance > 0.f)) {
    storage_->register_label(label);
    return;
  }

  float b = 1.f + 2 * C * margin;
  // b^2 - 8C(m - Cv) == (1 - 2Cm)^2 + 8C^2 v. This form is a sum of
  // non-negative terms, so rounding cannot make it negative and std::sqrt
  // cannot return NaN.
  float d = 1.f - 2 * C * margin;
  float gamma = -b + std::sqrt(d * d + 8 * C * C * variance);

  // `!(gamma > 0)` also rejects NaN, which `gamma <= 0` would let through.
  if (!(gamma > 0.f)) {
    storage_->register_label(label);
    return;
  }
  gamma /= 4 * C * variance;
  update(sfv, gamma, label, incorrect_label);
}