/** Builds the quadrature phase signal (squashogram) from the measured spectra
 * @param ws :: [input] workspace holding the measured spectra
 * @param phase :: [input] table workspace with the detector parameters; column
 * 1 is read as the asymmetry and column 2 as the phase of each detector
 * @param n0 :: [input] normalization constants, one per spectrum of ws
 * @return :: two-spectrum workspace with the quadrature phase signal
 * @throw std::invalid_argument if n0 does not hold one value per spectrum or
 * if no detector has a positive asymmetry
 */
API::MatrixWorkspace_sptr
PhaseQuadMuon::squash(const API::MatrixWorkspace_sptr &ws,
                      const API::ITableWorkspace_sptr &phase,
                      const std::vector<double> &n0) {

  // Counts at or below this threshold are regarded as too few for the stored
  // error to be used; the error is then taken from the expected decay curve.
  // The value is arbitrary and comes from the scientists' original code.
  const double minCountsForStoredError = 30.;

  const size_t numSpectra = ws->getNumberHistograms();
  const size_t numBins = ws->blocksize();

  // Muon lifetime converted to microseconds
  const double lifetime = PhysicalConstants::MuonLifetime * 1e6;

  if (numSpectra != n0.size()) {
    throw std::invalid_argument("Invalid normalization constants");
  }

  // Largest asymmetry in the table: all asymmetries are scaled by it below
  double largestAsym = 0.;
  for (size_t spec = 0; spec < numSpectra; ++spec) {
    const double candidate = phase->Double(spec, 1);
    if (candidate > largestAsym) {
      largestAsym = candidate;
    }
  }
  if (largestAsym == 0.0) {
    throw std::invalid_argument("Invalid detector asymmetries");
  }

  // Cosine and sine projections of one detector, weighted by its
  // normalization constant and its asymmetry relative to the largest one
  const auto projections = [&](size_t spec, double &cosPart, double &sinPart) {
    const double relAsym = phase->Double(spec, 1) / largestAsym;
    const double angle = phase->Double(spec, 2);
    cosPart = n0[spec] * relAsym * cos(angle);
    sinPart = n0[spec] * relAsym * sin(angle);
  };

  // Second moments of the projections over all detectors
  double sumCosCos = 0;
  double sumSinSin = 0;
  double sumCosSin = 0;
  for (size_t spec = 0; spec < numSpectra; ++spec) {
    double cosPart = 0;
    double sinPart = 0;
    projections(spec, cosPart, sinPart);
    sumCosCos += cosPart * cosPart;
    sumSinSin += sinPart * sinPart;
    sumCosSin += cosPart * sinPart;
  }

  // Multipliers shared by every detector; the off-diagonal one is used for
  // both the real and the imaginary coefficient
  const double onCosReal =
      2 * sumSinSin / (sumCosCos * sumSinSin - sumCosSin * sumCosSin);
  const double offDiagonal =
      2 * sumCosSin / (sumCosSin * sumCosSin - sumCosCos * sumSinSin);
  const double onSinImag =
      2 * sumCosCos / (sumCosCos * sumSinSin - sumCosSin * sumCosSin);

  // Per-detector weights of the real and imaginary output spectra
  std::vector<double> realCoeff(numSpectra);
  std::vector<double> imagCoeff(numSpectra);
  for (size_t spec = 0; spec < numSpectra; ++spec) {
    double cosPart = 0;
    double sinPart = 0;
    projections(spec, cosPart, sinPart);
    realCoeff[spec] = (onCosReal * cosPart + offDiagonal * sinPart) * 0.5;
    imagCoeff[spec] = (offDiagonal * cosPart + onSinImag * sinPart) * 0.5;
  }

  // Reference for the decay exponent: first X value of the first spectrum
  const double firstX = ws->x(0).front();

  // Output: two spectra sharing the X values of the first input spectrum
  API::MatrixWorkspace_sptr result = API::WorkspaceFactory::Instance().create(
      "Workspace2D", 2, numBins + 1, numBins);
  result->setSharedX(0, ws->sharedX(0));
  result->setSharedX(1, ws->sharedX(0));

  auto &reSignal = result->mutableY(0);
  auto &imSignal = result->mutableY(1);
  auto &reError = result->mutableE(0);
  auto &imError = result->mutableE(1);

  for (size_t bin = 0; bin < numBins; ++bin) {
    for (size_t spec = 0; spec < numSpectra; ++spec) {

      // Expected decay curve for this detector at this X value
      const double xValue = ws->x(spec)[bin];
      const double counts = ws->y(spec)[bin];
      const double expected = n0[spec] * exp(-(xValue - firstX) / lifetime);

      // Signal with the exponential decay removed
      const double signal = counts - expected;

      // Stored error when there are enough counts, otherwise the square root
      // of the expected counts
      double sigma;
      if (counts > minCountsForStoredError) {
        sigma = ws->e(spec)[bin];
      } else {
        sigma = sqrt(expected);
      }

      reSignal[bin] += realCoeff[spec] * signal;
      imSignal[bin] += imagCoeff[spec] * signal;
      // Errors are accumulated as variances and square-rooted after the loop
      reError[bin] += realCoeff[spec] * realCoeff[spec] * sigma * sigma;
      imError[bin] += imagCoeff[spec] * imagCoeff[spec] * sigma * sigma;
    }

    // Divide the decay factor of the first spectrum back out of the result
    const double refX = ws->getSpectrum(0).x()[bin];
    const double decay = exp(-(refX - firstX) / lifetime);
    reSignal[bin] /= decay;
    imSignal[bin] /= decay;
    reError[bin] = sqrt(reError[bin]) / decay;
    imError[bin] = sqrt(imError[bin]) / decay;
  }

  return result;
}
/** Forms the quadrature phase signal (squashogram)
 *
 * All non-empty spectra of the input workspace are combined into two output
 * spectra (index 0 and index 1, the "real" and "imaginary" parts). Each
 * detector is weighted according to its phase, its asymmetry relative to the
 * largest asymmetry in the table and its normalization constant. The muon
 * decay exponential is subtracted from every spectrum before combining and
 * divided out of the combined signal afterwards.
 *
 * @param ws :: [input] workspace containing the measured spectra
 * @param phase :: [input] table workspace containing the detector phases and
 * asymmetries; the two columns are located by (case-insensitive) name and
 * rows 0 .. nspec-1 are read
 * @param n0 :: [input] vector containing the normalization constants, one per
 * spectrum of ws
 * @return :: workspace containing the quadrature phase signal
 * @throw std::invalid_argument if n0 does not hold one value per spectrum or
 * if no detector has a positive asymmetry
 */
API::MatrixWorkspace_sptr
PhaseQuadMuon::squash(const API::MatrixWorkspace_sptr &ws,
                      const API::ITableWorkspace_sptr &phase,
                      const std::vector<double> &n0) {

  // Poisson limit: below this number we consider we don't have enough
  // statistics
  // to apply sqrt(N). This is an arbitrary number used in the original code
  // provided by scientists
  const double poissonLimit = 30.;

  // Muon life time in microseconds
  const double muLife = PhysicalConstants::MuonLifetime * 1e6;

  const size_t nspec = ws->getNumberHistograms();

  if (n0.size() != nspec) {
    throw std::invalid_argument("Invalid normalization constants");
  }

  // Lower-case the column names before looking up the phase and asymmetry
  // columns. The character goes through unsigned char because calling tolower
  // with a negative char value (any non-ASCII byte where char is signed) is
  // undefined behaviour.
  auto names = phase->getColumnNames();
  for (auto &name : names) {
    std::transform(
        name.begin(), name.end(), name.begin(),
        [](unsigned char c) { return static_cast<char>(::tolower(c)); });
  }
  // presumably findName returns the index of the column whose name matches
  // one of the accepted names -- defined outside this block, TODO confirm
  auto phaseIndex = findName(phaseNames, names);
  auto asymmetryIndex = findName(asymmNames, names);

  // Get the maximum asymmetry, ignoring entries equal to ASYMM_ERROR
  double maxAsym = 0.;
  for (size_t h = 0; h < nspec; h++) {
    const double asym = phase->Double(h, asymmetryIndex);
    if (asym > maxAsym && asym != ASYMM_ERROR) {
      maxAsym = asym;
    }
  }

  if (maxAsym == 0.0) {
    throw std::invalid_argument("Invalid detector asymmetries");
  }

  // Spectra whose counts are all exactly zero are excluded from the sums and
  // get zero coefficients
  std::vector<bool> emptySpectrum;
  emptySpectrum.reserve(nspec);
  std::vector<double> aj(nspec, 0.0), bj(nspec, 0.0);
  {
    // Calculate coefficients aj, bj

    // Weighted cosine/sine projection of each non-empty detector, kept so the
    // second pass does not have to recompute them
    std::vector<double> xProj(nspec, 0.0), yProj(nspec, 0.0);

    double sxx = 0.;
    double syy = 0.;
    double sxy = 0.;
    for (size_t h = 0; h < nspec; h++) {
      const auto &counts = ws->y(h);
      emptySpectrum.push_back(
          std::all_of(counts.begin(), counts.end(),
                      [](double value) { return value == 0.; }));
      if (!emptySpectrum[h]) {
        const double asym = phase->Double(h, asymmetryIndex) / maxAsym;
        const double phi = phase->Double(h, phaseIndex);
        const double X = n0[h] * asym * cos(phi);
        const double Y = n0[h] * asym * sin(phi);
        xProj[h] = X;
        yProj[h] = Y;
        sxx += X * X;
        syy += Y * Y;
        sxy += X * Y;
      }
    }

    // NOTE(review): the denominators below vanish (up to rounding) when the
    // (X, Y) projections of all non-empty detectors are collinear, e.g. with
    // a single non-empty spectrum; the coefficients are then not finite.
    const double lam1 = 2 * syy / (sxx * syy - sxy * sxy);
    const double mu1 = 2 * sxy / (sxy * sxy - sxx * syy);
    const double lam2 = 2 * sxy / (sxy * sxy - sxx * syy);
    const double mu2 = 2 * sxx / (sxx * syy - sxy * sxy);
    for (size_t h = 0; h < nspec; h++) {
      if (!emptySpectrum[h]) {
        aj[h] = (lam1 * xProj[h] + mu1 * yProj[h]) * 0.5;
        bj[h] = (lam2 * xProj[h] + mu2 * yProj[h]) * 0.5;
      }
    }
  }

  const size_t npoints = ws->blocksize();
  // Create and populate output workspace
  API::MatrixWorkspace_sptr ows =
      API::WorkspaceFactory::Instance().create(ws, 2, npoints + 1, npoints);

  // X
  ows->setSharedX(0, ws->sharedX(0));
  ows->setSharedX(1, ws->sharedX(0));

  // Phase quadrature
  auto &realY = ows->mutableY(0);
  auto &imagY = ows->mutableY(1);
  auto &realE = ows->mutableE(0);
  auto &imagE = ows->mutableE(1);

  const auto xPointData = ws->histogram(0).points();
  // First X value
  const double X0 = xPointData.front();

  // calculate exponential decay outside of the loop
  std::vector<double> expDecay = xPointData.rawData();
  std::transform(expDecay.begin(), expDecay.end(), expDecay.begin(),
                 [X0, muLife](double x) { return exp(-(x - X0) / muLife); });

  for (size_t i = 0; i < npoints; i++) {
    for (size_t h = 0; h < nspec; h++) {
      if (!emptySpectrum[h]) {
        // (X,Y,E) with exponential decay removed
        // NOTE(review): X is read from ws->x(h) here, whereas X0 and expDecay
        // are built from the point data of spectrum 0 -- confirm that mixing
        // the two X sources is intended.
        const double X = ws->x(h)[i];
        const double exponential = n0[h] * exp(-(X - X0) / muLife);
        const double Y = ws->y(h)[i] - exponential;
        const double E =
            (ws->y(h)[i] > poissonLimit) ? ws->e(h)[i] : sqrt(exponential);

        realY[i] += aj[h] * Y;
        imagY[i] += bj[h] * Y;
        // Errors are accumulated as variances; the square root is taken once
        // all detectors have been added
        realE[i] += aj[h] * aj[h] * E * E;
        imagE[i] += bj[h] * bj[h] * E * E;
      }
    }
    realE[i] = sqrt(realE[i]);
    imagE[i] = sqrt(imagE[i]);

    // Regain exponential decay
    realY[i] /= expDecay[i];
    imagY[i] /= expDecay[i];
    realE[i] /= expDecay[i];
    imagE[i] /= expDecay[i];
  }

  // New Y axis label
  ows->setYUnit("Asymmetry");

  return ows;
}