//----------------------------------------------------------------------
  // Imputes the latent probit data for one binomial observation and stores
  // the resulting sum of residuals on `data`.
  //
  // Each of the data->n() trials gets a latent deviate with mean
  // data->prediction() (eta) and unit scale, truncated at zero.
  // - Successes (data->y() of them) are drawn with truncation flag `true`.
  // - Failures are drawn with flag `false`.
  // NOTE(review): presumably `true` means truncation to (0, inf) and `false`
  // to (-inf, 0]; confirm against rtrun_norm_mt.
  //
  // When a count exceeds kCltThreshold, its contribution is drawn as one
  // normal deviate. That deviate has mean count * E[z] and variance
  // count * Var[z], using the truncated-normal moments (a CLT approximation).
  //
  // The value stored via set_sum_of_residuals is
  //   sum(latent deviates) - n * eta.
  //
  // Calls report_error if a count is negative or if successes exceed trials.
  void ProbitBartPosteriorSampler::impute_latent_data_point(DataType *data) {
    // Counts at or below this value are imputed one deviate at a time.
    constexpr int kCltThreshold = 5;

    const double eta = data->prediction();
    const int n = data->n();
    const int number_positive = data->y();
    const int number_negative = n - number_positive;

    // Without these checks a negative count silently skips its draws while
    // n * eta is still subtracted below, producing a meaningless residual.
    if (number_positive < 0 || n < 0) {
      report_error(
          "Negative values not allowed in "
          "ProbitBartPosteriorSampler::impute_latent_data_point().");
    }
    if (number_negative < 0) {
      report_error(
          "Success count exceeds trial count in "
          "ProbitBartPosteriorSampler::impute_latent_data_point().");
    }

    double sum_of_probits = 0;

    // Adds the contribution of `count` latent deviates, truncated on the side
    // selected by `positive`, directly into sum_of_probits.
    // Accumulating into the shared running sum (rather than returning a
    // partial sum) keeps the floating-point summation order identical to a
    // single sequential loop over all draws.
    auto accumulate_latent_draws = [&](int count, bool positive) {
      if (count > kCltThreshold) {
        double mean = 0;
        double variance = 1;
        trun_norm_moments(eta, 1, 0, positive, &mean, &variance);
        sum_of_probits += rnorm_mt(rng(),
                                   count * mean,
                                   sqrt(count * variance));
      } else {
        for (int i = 0; i < count; ++i) {
          sum_of_probits += rtrun_norm_mt(rng(), eta, 1, 0, positive);
        }
      }
    };

    // Successes are drawn before failures so the RNG stream is consumed in a
    // fixed order.
    accumulate_latent_draws(number_positive, true);
    accumulate_latent_draws(number_negative, false);

    data->set_sum_of_residuals(sum_of_probits - (n * eta));
  }
  // Draws the sum of the latent probit variables behind a binomial
  // observation.
  //
  // Args:
  //   rng: Source of every random draw made here.
  //   number_of_trials: Trial count n; rounded to the nearest integer.
  //   number_of_successes: Success count y; rounded to the nearest integer.
  //   eta: Location of the untruncated latent normal (unit scale).
  //
  // Returns:
  //   The sum of y deviates drawn with truncation flag `true` plus n - y
  //   deviates drawn with flag `false`, all truncated at zero.
  //   When either count exceeds clt_threshold_, that group is replaced by a
  //   single normal draw. Its mean is count * (truncated mean) and its
  //   variance is count * (truncated variance).
  //
  // Calls report_error if either count is negative or if y > n.
  double BinomialProbitDataImputer::impute(RNG &rng, double number_of_trials,
                                           double number_of_successes,
                                           double eta) const {
    const int64_t trials = lround(number_of_trials);
    const int64_t successes = lround(number_of_successes);

    if (successes < 0 || trials < 0) {
      report_error(
          "Negative values not allowed in "
          "BinomialProbitDataImputer::impute().");
    }
    if (successes > trials) {
      report_error(
          "Success count exceeds trial count in "
          "BinomialProbitDataImputer::impute.");
    }

    double total = 0;

    // Folds `count` truncated latent deviates into `total`. The `upper` flag
    // selects the truncation side. Everything is added straight into the
    // running total so summation order matches one sequential pass.
    auto add_latent_draws = [&](int64_t count, bool upper) {
      if (count > clt_threshold_) {
        // The sum of `count` iid truncated normals is approximately normal,
        // with mean count * mean and variance count * variance.
        double mean;
        double variance;
        trun_norm_moments(eta, 1, 0, upper, &mean, &variance);
        total += rnorm_mt(rng, count * mean, sqrt(count * variance));
        return;
      }
      // TODO: If count is large-ish but not quite clt_threshold_ then we
      // might waste some time here constantly rebuilding the same TnSampler
      // object.
      for (int64_t draw = 0; draw < count; ++draw) {
        total += rtrun_norm_mt(rng, eta, 1, 0, upper);
      }
    };

    add_latent_draws(successes, true);
    add_latent_draws(trials - successes, false);
    return total;
  }