// Example #1
// 0
/**
 * More-Thuente line search along `direc` starting from `param`.
 *
 * Looks for a step `stepsize` such that the trial point
 * x = param + stepsize * direc satisfies both
 *   - sufficient decrease: f(x) <= finit + stepsize * alpha_ * dginit, and
 *   - the strong curvature condition: |direc^T grad(x)| <= beta_ * (-dginit),
 * where dginit = direc^T grad at the starting point.
 *
 * @param param    Starting point. On success it is swapped with the accepted
 *                 trial point; on failure it is left unchanged.
 * @param direc    Search direction; direc.dot(grad) must not be positive.
 * @param grad     On entry, the gradient at `param`. It is passed to funcgrad
 *                 on every evaluation (presumably overwritten with the trial
 *                 gradient), so after any evaluation it no longer refers to
 *                 `param`, even when the search fails.
 * @param finit    Objective value at `param`.
 * @param stepsize In: initial trial step (must be >= 0). Out: the last step
 *                 that was evaluated (the accepted one on success).
 * @param funcgrad Callback returning f at its first argument; assumed to write
 *                 the gradient into its second argument.
 * @return true if both conditions hold within maxtries_ evaluations; false on
 *         a rounding-error stall, hitting minstep_/maxstep_, an interval of
 *         uncertainty narrower than parameps_, or running out of tries.
 */
bool LineSearcher::MoreThuenteLineSearch(
    DenseVector &param, DenseVector &direc, DenseVector &grad, double finit,
    double &stepsize,
    std::function<double(DenseVector &, DenseVector &)> &funcgrad) {

  itercnt_ = 0;
  int brackt = 0;  // Set by update_trial_interval(); when nonzero the interval
                   // of uncertainty is bounded by stx and sty.
  int stage1 = 1;  // Nonzero until the first-stage exit test below passes.
  int uinfo = 0;   // Status from the most recent update_trial_interval().
  double dg;                 // Directional derivative at the trial step.
  double stx, fx, dgx;       // Endpoint with the lowest f obtained so far.
  double sty, fy, dgy;       // Other endpoint of the interval of uncertainty.
  double fxm, dgxm, fym, dgym, fm, dgm;  // Modified-function counterparts.
  double ftest1, dginit, dgtest;
  double width, prev_width;
  double stmin, stmax;
  double fval;

  if (stepsize < 0) {
    LOG(FATAL) << "Stepsize less than 0";
    return false;
  }

  dginit = direc.dot(grad);
  if (dginit > 0) {
    LOG(FATAL) << "Direction not decent";
    return false;
  }

  if (tparam_.size() != param.size()) {
    tparam_.resize(param.size());
  }

  /* Initialize local variables. */
  dgtest = alpha_ * dginit;
  width = maxstep_ - minstep_;
  prev_width = 2.0 * width;

  stx = sty = 0.;
  fx = fy = finit;
  dgx = dgy = dginit;

  while (itercnt_ < maxtries_) {
    /*
    Set the minimum and maximum steps to correspond to the
    present interval of uncertainty.
    */
    if (brackt) {
      stmin = std::min(stx, sty);
      // The upper end must be the max. Using min here collapsed the interval
      // to a single point, so the search aborted as soon as it bracketed.
      stmax = std::max(stx, sty);
    } else {
      stmin = stx;
      stmax = stepsize + 4.0 * (stepsize - stx);
    }

    /* Clip the step in the range of [minstep_, maxstep_]. */
    if (stepsize < minstep_)
      stepsize = minstep_;
    if (stepsize > maxstep_)
      stepsize = maxstep_;

    /*
    If an unusual termination is to occur then let
    stepsize be the lowest point obtained so far.
    */
    if ((brackt && ((stepsize <= stmin || stepsize >= stmax) ||
                    (itercnt_ + 1 >= maxtries_) || uinfo != 0)) ||
        (brackt && (stmax - stmin <= parameps_ * stmax))) {
      stepsize = stx;
    }

    // Evaluate f and its directional derivative at the trial point.
    tparam_ = param + stepsize * direc;
    fval = funcgrad(tparam_, grad);
    dg = grad.dot(direc);

    // Right-hand side of the sufficient-decrease test at this step.
    ftest1 = finit + stepsize * dgtest;
    ++itercnt_;

    /* Test for errors and convergence. */
    if (brackt && ((stepsize <= stmin || stmax <= stepsize) || uinfo != 0)) {
      /* Rounding errors prevent further progress. */
      return false;
    }
    if (stepsize == maxstep_ && fval <= ftest1 && dg <= dgtest) {
      /* The step is the maximum value. */
      return false;
    }
    if (stepsize == minstep_ && (ftest1 < fval || dgtest <= dg)) {
      /* The step is the minimum value. */
      return false;
    }
    if (brackt && (stmax - stmin) <= parameps_ * stmax) {
      /* Relative width of the interval of uncertainty is at most parameps_. */
      return false;
    }
    if (maxtries_ <= itercnt_) {
      /* Maximum number of iteration. */
      return false;
    }

    if (fval <= ftest1 && std::fabs(dg) <= beta_ * (-dginit)) {
      /* The sufficient decrease condition and the directional derivative
       * condition hold. */
      param.swap(tparam_);
      return true;
    }

    /*
    In the first stage we seek a step for which the modified
    function has a nonpositive value and nonnegative derivative.
    */
    if (stage1 && fval <= ftest1 && std::min(alpha_, beta_) * dginit <= dg) {
      stage1 = 0;
    }

    /*
    A modified function is used to predict the step only if
    we have not obtained a step for which the modified
    function has a nonpositive function value and nonnegative
    derivative, and if a lower function value has been
    obtained but the decrease is not sufficient.
    */
    if (stage1 && ftest1 < fval && fval <= fx) {
      /* Define the modified function and derivative values. */
      fm = fval - stepsize * dgtest;
      fxm = fx - stx * dgtest;
      fym = fy - sty * dgtest;
      dgm = dg - dgtest;
      dgxm = dgx - dgtest;
      dgym = dgy - dgtest;

      /*
      Call update_trial_interval() to update the interval of
      uncertainty and to compute the new step.
      */
      uinfo =
          update_trial_interval(&stx, &fxm, &dgxm, &sty, &fym, &dgym, &stepsize,
                                &fm, &dgm, stmin, stmax, &brackt);

      /* Reset the function and gradient values for f. */
      fx = fxm + stx * dgtest;
      fy = fym + sty * dgtest;
      dgx = dgxm + dgtest;
      dgy = dgym + dgtest;
    } else {
      /*
      Call update_trial_interval() to update the interval of
      uncertainty and to compute the new step.
      */
      uinfo = update_trial_interval(&stx, &fx, &dgx, &sty, &fy, &dgy, &stepsize,
                                    &fval, &dg, stmin, stmax, &brackt);
    }

    /*
    Force a sufficient decrease in the interval of uncertainty:
    if it did not shrink to below 66% of its width two updates ago,
    bisect it.
    */
    if (brackt) {
      if (0.66 * prev_width <= std::fabs(sty - stx)) {
        stepsize = stx + 0.5 * (sty - stx);
      }
      prev_width = width;
      width = std::fabs(sty - stx);
    }
  }

  // Only reachable when maxtries_ <= 0: no evaluation was performed.
  return false;
}
// Example #2
// 0
/**
 * Backtracking line search along `direc` starting from `param`.
 *
 * Repeatedly evaluates funcgrad at param + stepsize * direc. After each
 * rejected trial it multiplies the step by 0.5 (shrink) or 2.1 (expand). It
 * stops when the condition selected by lscondtype_ holds:
 *   - Armijo: f(x) <= finit + stepsize * alpha_ * dginit;
 *   - Wolfe:  Armijo and direc^T grad(x) >= beta_ * dginit;
 *   - any other lscondtype_: additionally direc^T grad(x) <= -beta_ * dginit
 *     (strong-Wolfe style).
 *
 * @param param    Starting point; swapped with the accepted trial point on
 *                 success, untouched on failure.
 * @param direc    Search direction; direc.dot(grad) must not be positive.
 * @param grad     Gradient at `param` on entry. It is passed to funcgrad on
 *                 every evaluation (presumably overwritten with the trial
 *                 gradient).
 * @param finit    Objective value at `param`.
 * @param stepsize In: initial trial step. Out: the accepted step on success;
 *                 on failure, whatever value the search had reached.
 * @param funcgrad Callback returning f at its first argument.
 * @return true when the selected condition is met; false if a trial step lies
 *         outside [minstep_, maxstep_] or maxtries_ trials are rejected.
 */
bool LineSearcher::BackTrackLineSearch(
    DenseVector &param, DenseVector &direc, DenseVector &grad, double finit,
    double &stepsize,
    std::function<double(DenseVector &, DenseVector &)> &funcgrad) {
  itercnt_ = 0;

  // Multipliers applied to the step after a rejected trial.
  const double kShrink = 0.5;
  const double kExpand = 2.1;

  const double dginit = direc.dot(grad);
  if (dginit > 0) {
    LOG(FATAL) << "initial direction is not a decent direction";
    return false;
  }

  if (tparam_.size() != param.size()) {
    tparam_.resize(param.size());
  }

  // Slope used on the right-hand side of the sufficient-decrease test.
  const double dgtest = dginit * alpha_;

  for (; itercnt_ < maxtries_; ++itercnt_) {
    tparam_ = param + stepsize * direc;
    const double fval = funcgrad(tparam_, grad);

    // Unless a branch below says otherwise, a rejected trial shrinks the step.
    double factor = kShrink;

    // Sufficient decrease holds. Written as !(a > b) so that a NaN objective
    // also takes this branch.
    if (!(fval > finit + stepsize * dgtest)) {
      if (lscondtype_ == LineSearchConditionType::Armijo) {
        break;
      }

      const double dgval = direc.dot(grad);
      if (dgval < beta_ * dginit) {
        // Still descending too steeply: try a longer step.
        factor = kExpand;
      } else if (lscondtype_ == LineSearchConditionType::Wolfe ||
                 !(dgval > -beta_ * dginit)) {
        // Regular Wolfe accepted, or the strong-Wolfe upper bound also holds.
        break;
      }
      // Otherwise the slope overshot the strong-Wolfe bound: keep kShrink.
    }

    if (stepsize < minstep_) {
      LOG(ERROR) << "Small than smallest step size";
      return false;
    }
    if (stepsize > maxstep_) {
      LOG(ERROR) << "Large than largest step size";
      return false;
    }

    stepsize *= factor;
  }

  if (itercnt_ >= maxtries_) {
    LOG(ERROR) << "Exceed Maximum number of iteration count";
    return false;
  }

  param.swap(tparam_);
  return true;
}