/* stPbxFunction - background density evaluated at a single star's position.
 *
 * The star's (l, b, r) coordinates are converted to Cartesian xyz with
 * lbr2xyz, and a flattened, broken power-law profile is evaluated:
 *
 *     rg = sqrt(x^2 + y^2 + (z / q)^2)
 *     P  = 1 / ( rg^alpha * (rg + r0)^(3 - alpha + delta) )
 *
 * Parameters:
 *   coordpar - star coordinates in lbr format (3 values), passed to lbr2xyz
 *   bpars    - background parameters (4 values):
 *                bpars[0] alpha - inner slope: P ~ rg^-alpha for rg << r0
 *                bpars[1] q     - z-axis ratio of the density ellipsoids
 *                                 (q < 1 flattens them toward z = 0,
 *                                 q = 1 is spherical)
 *                bpars[2] r0    - break radius between inner and outer slopes
 *                bpars[3] delta - outer slope: P ~ rg^-(3 + delta) for rg >> r0
 *
 * Return: the (unnormalised) density P; higher values mean the star is more
 * likely to belong to the background distribution.
 *   -1 is returned when q == 0 (the (z / q) term is undefined).
 * No NaN check is performed: NaN inputs propagate to a NaN result, and a star
 * exactly at rg == 0 may produce a non-finite value.
 */
double stPbxFunction(const double* coordpar, const double* bpars)
{
    double xg[3];
    const double alpha = bpars[0];
    const double q     = bpars[1];
    const double r0    = bpars[2];
    const double delta = bpars[3];
    double zq, rg;

    /* if q is 0, there is no probability (z / q would divide by zero) */
    if (q == 0)
        return -1;

    lbr2xyz(coordpar, xg);

    /* elliptical radius: z is stretched by 1/q before taking the norm */
    zq = xg[2] / q;
    rg = sqrt(xg[0] * xg[0] + xg[1] * xg[1] + zq * zq);

    return 1.0 / (pow(rg, alpha) * pow(rg + r0, 3.0 - alpha + delta));
}
/* streamC - Cartesian position of a point on a wedge's great circle.
 *
 * gc2lb(wedge, mu, 0.0) supplies the (l, b) of the point at mu with nu fixed
 * at 0 (assumed great-circle -> galactic conversion); that direction is
 * combined with distance r (W component zeroed) and converted to xyz via
 * lbr2xyz. Judging by the name this is the stream's reference ("centre")
 * point — TODO confirm against callers.
 */
static inline mwvector streamC(const AstronomyParameters* ap, int wedge, real mu, real r)
{
    LB galCoord = gc2lb(wedge, mu, 0.0);
    mwvector pointLbr;

    R(pointLbr) = r;
    W(pointLbr) = 0.0;
    L(pointLbr) = LB_L(galCoord);
    B(pointLbr) = LB_B(galCoord);

    return lbr2xyz(ap, pointLbr);
}
/* separation - classify one star and optionally record where it ended up.
 *
 * When the twoPanel flag is set, twoPanelSeparation updates ss from the
 * stream/background probabilities; otherwise nonTwoPanelSeparation is used.
 * prob_ok then picks the star's stream: a positive result is treated as a
 * 1-based stream index and that stream's q counter is incremented.
 * If f is non-NULL, one line "s_ok x y z" is written, where x y z is the
 * star position after lbr2xyz and transform_point (using cmatrix and a sun
 * vector whose X component is ap->m_sun_r0).
 */
static void separation(FILE* f,
                       const AstronomyParameters* ap,
                       const SeparationResults* results,
                       const mwmatrix cmatrix,
                       StreamStats* ss,
                       const real* st_probs,
                       real bg_prob,
                       real epsilon_b,
                       mwvector current_star_point)
{
    int streamIdx;
    mwvector cartesian, rotated;
    mwvector sunPos = ZERO_VECTOR;

    X(sunPos) = ap->m_sun_r0;

    if (!twoPanel)
        nonTwoPanelSeparation(ss, ap->number_streams);
    else
        twoPanelSeparation(ap, results, ss, st_probs, bg_prob, epsilon_b);

    /* decide which stream (if any) this star is assigned to */
    streamIdx = prob_ok(ss, ap->number_streams);
    if (streamIdx > 0)
        ++ss[streamIdx - 1].q;

    cartesian = lbr2xyz(ap, current_star_point);
    rotated = transform_point(ap, cartesian, cmatrix, sunPos);

    if (f != NULL)
        fprintf(f, "%d %lf %lf %lf\n", streamIdx, X(rotated), Y(rotated), Z(rotated));
}
/* stPsgFunction - stream density evaluated at a single star's position.
 *
 * The stream is modelled as a straight line with a Gaussian cross-section:
 *   c = xyz of the point at position (mu, nu = 0) of the given wedge, at
 *       distance r (built as (l, b, r) and converted with lbr2xyz)
 *   a = unit axis (sin(theta)cos(phi), sin(theta)sin(phi), cos(theta))
 *   d = perpendicular distance from the star to the line through c along a
 *   P = exp(-d^2 / (2 sigma^2))
 * lbr coordinates are assumed solar-centred and xyz galactic-centred, per
 * the lbr2xyz convention.
 *
 * Parameters:
 *   coordpar        - star coordinates in lbr format (3 values)
 *   spars           - stream parameters: [0] mu, [1] r, [2] theta,
 *                     [3] phi, [4] sigma
 *   wedge           - wedge/stripe number used by the coordinate conversions
 *   sgr_coordinates - 0: convert mu via atGCToEq -> atEqToGal;
 *                     1: convert mu via gcToSgr -> sgrToGal
 *
 * Return: P in [0, 1] for finite inputs; higher means the star is more
 * likely to be in the stream.
 *    0 when |sigma| < 0.0001 (degenerate Gaussian).
 *   -1 when sgr_coordinates is neither 0 nor 1 (error).
 */
double stPsgFunction(const double* coordpar, const double* spars, int wedge, int sgr_coordinates)
{
    //update: allow for new coordinate transforms
    double xyz[3], lbr[3], a[3], c[3];
    double dotted, xyz_norm, prob;
    double ra, dec, lamda, beta, l, b;
    const double mu    = spars[0];
    const double r     = spars[1];
    const double theta = spars[2];
    const double phi   = spars[3];
    const double sigma = spars[4];

    //update: convert from mu, nu, r geometry to a and c geometry
    if (sgr_coordinates == 0)
    {
        atGCToEq(mu, 0, &ra, &dec, get_node(), wedge_incl(wedge));
        atEqToGal(ra, dec, &l, &b);
    }
    else if (sgr_coordinates == 1)
    {
        gcToSgr(mu, 0, wedge, &lamda, &beta);
        sgrToGal(lamda, beta, &l, &b);
        // <<<make sure the conversion is correct (check with conversiontester.vb)>>>
        MW_DEBUG(" wedge=%i, mui=%f, nui=0, lamda=%f, beta=%f, l=%f, b=%f", wedge, mu, lamda, beta, l, b);
    }
    else
    {
        /* l and b were never set: continuing would read indeterminate values. */
        fprintf(stderr, "Error: sgr_coordinates not valid\n");
        return -1;
    }

    /* Sigma near 0 so star prob is 0; skip the geometry entirely. */
    if (fabs(sigma) < 0.0001)
        return 0;

    /* c: Cartesian position of the stream reference point */
    lbr[0] = l;
    lbr[1] = b;
    lbr[2] = r;
    lbr2xyz(lbr, c);

    /* a: unit vector along the stream axis */
    a[0] = sin(theta) * cos(phi);
    a[1] = sin(theta) * sin(phi);
    a[2] = cos(theta);

    /* star position relative to c ... */
    lbr2xyz(coordpar, xyz);
    xyz[0] -= c[0];
    xyz[1] -= c[1];
    xyz[2] -= c[2];

    /* ... minus its projection onto a leaves the perpendicular offset */
    dotted = dotp(a, xyz);
    xyz[0] -= dotted * a[0];
    xyz[1] -= dotted * a[1];
    xyz[2] -= dotted * a[2];

    xyz_norm = norm(xyz);

    MW_DEBUG("dotted: %lf, xyz_norm: %lf, sigma: %lf\n", dotted, xyz_norm, sigma);
    prob = exp( -(xyz_norm * xyz_norm) / 2 / (sigma * sigma) );

    MW_DEBUG("prob before ref: %lf\n", prob);
    return prob;
}
/* Allocate size bytes for separation(), terminating the process on failure.
 * This matches separation()'s existing exit(1) policy for unreadable input.
 * A NULL result for a zero-byte request is not treated as an error. */
static void* separation_xmalloc(size_t size)
{
    void* p = malloc(size);
    if (p == NULL && size != 0)
    {
        fprintf(stderr, "APP: out of memory allocating %lu bytes\n", (unsigned long) size);
        exit(1);
    }
    return p;
}

/* separation - assign every star in star_points_file to a stream or the
 * background and write one line per star to filename.
 *
 * For each star, the background and per-stream probabilities come from
 * calculate_probabilities. They are weighted by the mixture weights
 * (epsilon_b, epsilon_s) and divided by the normalising integrals, giving
 *   sprob[j] = psg[j] / (sum(psg) + pbx).
 * prob_ok(sprob) then picks the star's stream.
 *
 * Each output line is "s_ok x y z":
 *   - s_ok is prob_ok's result (a positive value is used as a 1-based stream
 *     index);
 *   - x y z is the star position after lbr2xyz and transform_point with the
 *     matrix from get_transform(stripe normal, (0, 0, 1)).
 *
 * Parameters:
 *   filename            - output path (opened with "w")
 *   background_integral - normalising integral of the background model
 *   stream_integrals    - normalising integrals, ap->number_streams entries
 *
 * Uses the globals ap, sp and star_points_file; sp is freed and replaced by
 * a fresh read of star_points_file.
 *
 * Error handling:
 *   - an invalid ap->sgr_coordinates, or an output file that cannot be
 *     opened, prints an error and returns;
 *   - a read or allocation failure terminates the process with exit(1).
 */
void separation(const char* filename, double background_integral, double* stream_integrals)
{
    int q[ap->number_streams];
    double nstars[ap->number_streams];
    int total;
    double sprob[ap->number_streams];
    double prob_s[ap->number_streams];
    double prob_b;
    double pbx;
    double psg[ap->number_streams];
    int twoPanel;
    double** cmatrix;
    double dnormal[3];
    double dortho[3];
    double xsun[3];
    double epsilon_s[ap->number_streams];
    double epsilon_b;
    double denom;
    double star_coords[3];
    double starxyz[3];
    double starxyzTransform[3];
    int s_ok = 0;
    int i, j, retval;
    FILE* file;
    double reff_xr_rp3, *qw_r3_N, *r_point;

    /* 1: weight every star with the fitted model; otherwise every star is
       given probability 1 for every stream. */
    twoPanel = 1;
    for (j = 0; j < ap->number_streams; j++)
    {
        nstars[j] = 0;
        q[j] = 0;
    }
    total = 0;
    prob_ok_init();

    printf("Integral complete.\n Beginning probability calculations...\n");

    /* Stripe normal for the wedge. Validated before anything is opened or
       allocated: an invalid value would leave dnormal uninitialized. */
    if (ap->sgr_coordinates == 0)
    {
        stripe_normal(ap->wedge, dnormal);
    }
    else if (ap->sgr_coordinates == 1)
    {
        sgr_stripe_normal(ap->wedge, dnormal);
    }
    else
    {
        fprintf(stderr, "Error: ap->sgr_coordinates not valid\n");
        return;
    }

    file = fopen(filename, "w");
    if (file == NULL)
    {
        fprintf(stderr, "APP: error opening output file %s\n", filename);
        return;
    }

    /* Replace any previously loaded stars with a fresh read. */
    free_star_points(sp);
    free(sp);
    sp = (STAR_POINTS*)separation_xmalloc(sizeof(STAR_POINTS));
    retval = read_star_points(star_points_file, sp);
    if (retval)
    {
        fprintf(stderr, "APP: error reading star points: %d\n", retval);
        exit(1);
    }
    printf("read %d stars.\n", sp->number_stars);


    /* 3x3 transformation built from the stripe normal and the z axis;
       applied to every star before it is written out. */
    cmatrix = (double**)separation_xmalloc(sizeof(double*) * 3);
    for (i = 0; i < 3; i++)
        cmatrix[i] = (double*)separation_xmalloc(sizeof(double) * 3);
    dortho[0] = 0.0;
    dortho[1] = 0.0;
    dortho[2] = 1.0;
    get_transform(dnormal, dortho, cmatrix);

    printf("\nTransformation matrix:\n");
    printf("\t%lf %lf %lf\n", cmatrix[0][0], cmatrix[0][1], cmatrix[0][2]);
    printf("\t%lf %lf %lf\n", cmatrix[1][0], cmatrix[1][1], cmatrix[1][2]);
    printf("\t%lf %lf %lf\n\n", cmatrix[2][0], cmatrix[2][1], cmatrix[2][2]);

    /* Sun position handed to transform_point (hard-coded x = -8.5). */
    xsun[0] = -8.5;
    xsun[1] = 0.0;
    xsun[2] = 0.0;

    printf("==============================================\n");
    printf("bint: %lf", background_integral);
    for (j = 0; j < ap->number_streams; j++)
    {
        printf(", ");
        printf("sint[%d]: %lf", j, stream_integrals[j]);
    }
    printf("\n");

    /* Mixture weights: softmax over the stream weights with the background's
       weight fixed at exp(0) = 1, so epsilon_b + sum(epsilon_s) == 1. */
    denom = 1.0;
    for (j = 0; j < ap->number_streams; j++)
    {
        denom += exp(ap->stream_weights[j]);
    }

    for (j = 0; j < ap->number_streams; j++)
    {
        epsilon_s[j] = exp(ap->stream_weights[j]) / denom;
        printf("epsilon_s[%d]: %lf\n", j, epsilon_s[j]);
    }
    epsilon_b = 1.0 / denom;
    printf("epsilon_b:    %lf\n", epsilon_b);

    r_point = (double*)separation_xmalloc(sizeof(double) * ap->convolve);
    qw_r3_N = (double*)separation_xmalloc(sizeof(double) * ap->convolve);

    init_constants(ap);

    printf("initialized constants\n");

    for (i = 0; i < sp->number_stars; i++)
    {
        MW_DEBUG("[%d/%d] setting star coords\n", i, sp->number_stars);
        star_coords[0] = sp->stars[i][0];
        star_coords[1] = sp->stars[i][1];
        star_coords[2] = sp->stars[i][2];
        MW_DEBUG("star_coords: %g %g %g\n", star_coords[0], star_coords[1], star_coords[2]);

        MW_DEBUG("twoPanel: %d\n", twoPanel);

        if (twoPanel == 1)
        {
            double psgSum = 0;

            MW_DEBUG("setting probability constants\n");
            set_probability_constants(ap->convolve, star_coords[2], r_point, qw_r3_N, &reff_xr_rp3);
            MW_DEBUG("calculating probabilities\n");
            calculate_probabilities(r_point, qw_r3_N, reff_xr_rp3, star_coords, ap, &prob_b, prob_s);
            MW_DEBUG("calculated probabilities\n");

            MW_DEBUG("prob_s: %lf\n", prob_s[0]);
            MW_DEBUG("prob_b: %lf\n", prob_b);

            /* weighted, normalised background and stream densities */
            pbx = epsilon_b * prob_b / background_integral;

            for (j = 0; j < ap->number_streams; j++)
            {
                psg[j] = epsilon_s[j] * prob_s[j] / stream_integrals[j];
            }

            MW_DEBUG("pbx: %g\n", pbx);
            MW_DEBUG("psg: %g\n", psg[0]);

            for (j = 0; j < ap->number_streams; j++)
            {
                psgSum += psg[j];
            }

            /* posterior probability of each stream given this star */
            for (j = 0; j < ap->number_streams; j++)
            {
                sprob[j] = psg[j] / (psgSum + pbx);
            }

            MW_DEBUG("sprob: %g\n", sprob[0]);

            for (j = 0; j < ap->number_streams; j++)
            {
                nstars[j] += sprob[j];
            }

            MW_DEBUG("nstars: %g\n", nstars[0]);
        }
        else
        {
            for (j = 0; j < ap->number_streams; j++)
            {
                sprob[j] = 1.0;
                nstars[j] += 1.0;
            }
        }

        /* determine if star with sprob should be put into stream */
        s_ok = prob_ok(ap->number_streams, sprob);

        MW_DEBUG("s_ok: %d\n", s_ok);

        if (s_ok >= 1)
        {
            q[s_ok-1]++;
        }

        lbr2xyz(star_coords, starxyz);
        transform_point(starxyz, cmatrix, xsun, starxyzTransform);

        fprintf(file, "%d %lf %lf %lf\n", s_ok, starxyzTransform[0], starxyzTransform[1], starxyzTransform[2]);

        total += 1;

        if ( (total % 10000) == 0 )
            printf("%d\n", total);
    }

    printf("%d total stars\n", total);
    for (j = 0; j < ap->number_streams; j++)
    {
        /* an empty star file would otherwise print 0/0 = nan */
        printf("%lf in stream[%d] (%lf%%)\n", nstars[j], j,
               total > 0 ? nstars[j] / total * 100 : 0.0);
    }

    for (j = 0; j < ap->number_streams; j++)
    {
        printf("%d stars separated into stream\n", q[j]);
    }

    /* fclose flushes buffered output; a failure means the file is incomplete */
    if (fclose(file) != 0)
        fprintf(stderr, "APP: error writing output file %s\n", filename);
    else
        printf("Output written to: %s\n", filename);

    for (i = 0; i < 3; i++)
        free(cmatrix[i]);
    free(cmatrix);
    free(r_point);
    free(qw_r3_N);
    free_constants(ap);
}