//////////////
// eBAT API //
//////////////
// eBATS key-pair generation.
//
// Fills sk with SECRETKEY_BYTES random bytes, forcing bit 0 of sk[0] to zero.
// sk[0] is the least significant byte of the scalar, so the scalar is even.
// Multiplies a fixed base point by that scalar (KSmul). The y, z and t
// coordinates of the product are then multiplied by Kinv(x) and serialised
// into pk at byte offsets 0, 16 and 32.
// Kinv is presumably field inversion, i.e. the point is normalised to x = 1.
// TODO confirm.
//
//   pk : output buffer for the public key; must hold at least 32 bytes plus
//        one elt2bytes encoding (presumably 48 bytes in total -- confirm).
//   sk : output buffer of SECRETKEY_BYTES bytes for the secret key.
// Returns 0 unconditionally.
int crypto_dh_keypair(uchar *pk, uchar *sk)
{
  int pos;
  mpz_t scalar;
  Kfield K;
  KSpoint base_point, res;
  KSparam KS;

  // Secret key: random bytes with the lowest bit of sk[0] cleared.
  randombytes(sk, SECRETKEY_BYTES);
  sk[0] &= ~1;

  // Field set-up and the fixed base point (x, y, z, t) = (1, 1, 4, t0).
  Kfield_init();
  KSinit(K, base_point);
  Kset_uipoly(base_point->x, 1);
  Kset_uipoly(base_point->y, 1);
  Kset_uipoly(base_point->z, 4);
  {
    // t0 exceeds a 64-bit word, so it is parsed into a temporary mpz and its
    // limbs are handed to Kset_uipoly_wide.
    mpz_t t0;
    mpz_init_set_str(t0, "908681679267597915035722095941517", 10);
    Kset_uipoly_wide(base_point->t, t0[0]._mp_d, t0[0]._mp_size);
    mpz_clear(t0);
  }

  // Interpret sk as a little-endian integer: Horner's rule from the most
  // significant byte down to sk[1], then add sk[0] without a final shift.
  mpz_init_set_ui(scalar, 0);
  pos = SECRETKEY_BYTES;
  while (--pos > 0) {
    mpz_add_ui(scalar, scalar, sk[pos]);
    mpz_mul_2exp(scalar, scalar, 8);
  }
  mpz_add_ui(scalar, scalar, sk[0]);

  // res = scalar * base_point; afterwards res->x holds Kinv(x) and
  // y, z, t are each multiplied by it.
  KSinit(K, res);
  StandardKS(K, KS);
  KSmul(K, res, base_point, scalar, KS);
  Kinv(res->x, res->x);
  Kmul(res->y, res->y, res->x);
  Kmul(res->z, res->z, res->x);
  Kmul(res->t, res->t, res->x);

  // Public key: y, z, t encodings placed 16 bytes apart.
  elt2bytes(pk, res->y);
  elt2bytes(pk + 16, res->z);
  elt2bytes(pk + 32, res->t);

  // Release the scalar, both points and the field state.
  mpz_clear(scalar);
  KSclear(K, base_point);
  KSclear(K, res);
  Kfield_clear();

  return 0;
}
// eBATS shared-secret computation.
//
// Decodes a peer point from pk: y, z and t come from byte offsets 0, 16 and
// 32, and x is set to 1 (limb 0 = 1, all other limbs = 0). The point is
// multiplied by the scalar obtained by reading sk as a little-endian integer.
// The y, z and t coordinates of the product are then multiplied by Kinv(x)
// and written to s at offsets 0, 16 and 32.
// Kinv is presumably field inversion, i.e. the point is normalised to x = 1.
// TODO confirm.
// pk is decoded as-is; nothing in this function validates it.
//
//   s  : output buffer for the shared secret (same layout as a public key).
//   pk : peer public key, as produced by crypto_dh_keypair.
//   sk : own secret key of SECRETKEY_BYTES bytes.
// Returns 0 unconditionally.
int crypto_dh(uchar *s,
    const uchar *pk,
    const uchar *sk)
{
  int pos;
  mpz_t scalar;
  Kfield K;
  KSpoint base_point, res;
  KSparam KS;

  // Field set-up and storage for the peer point.
  Kfield_init();
  KSinit(K, base_point);

  // Peer point from pk; the x coordinate is the constant 1.
  bytes2elt(base_point->y, pk);
  bytes2elt(base_point->z, pk + 16);
  bytes2elt(base_point->t, pk + 32);
  base_point->x[0] = 1UL;
  pos = 1;
  while (pos < LIMB_PER_ELT)
    base_point->x[pos++] = 0UL;

  // Interpret sk as a little-endian integer: Horner's rule from the most
  // significant byte down to sk[1], then add sk[0] without a final shift.
  mpz_init_set_ui(scalar, 0);
  pos = SECRETKEY_BYTES;
  while (--pos > 0) {
    mpz_add_ui(scalar, scalar, sk[pos]);
    mpz_mul_2exp(scalar, scalar, 8);
  }
  mpz_add_ui(scalar, scalar, sk[0]);

  // res = scalar * peer point; afterwards res->x holds Kinv(x) and
  // y, z, t are each multiplied by it.
  KSinit(K, res);
  StandardKS(K, KS);
  KSmul(K, res, base_point, scalar, KS);
  Kinv(res->x, res->x);
  Kmul(res->y, res->y, res->x);
  Kmul(res->z, res->z, res->x);
  Kmul(res->t, res->t, res->x);

  // Shared secret: y, z, t encodings placed 16 bytes apart.
  elt2bytes(s, res->y);
  elt2bytes(s + 16, res->z);
  elt2bytes(s + 32, res->t);

  // Release the scalar, both points and the field state.
  mpz_clear(scalar);
  KSclear(K, base_point);
  KSclear(K, res);
  Kfield_clear();

  return 0;
}
void ElectronScattering::CEDA::collect(const ElectronScattering& es, int iq, const diagMatrix& chiKS0, diagMatrix& num, diagMatrix& den)
{	int nBands = Fsum.nRows();
	//MPI accumulate:
	Fsum.allReduce(MPIUtil::ReduceSum);
	FEsum.allReduce(MPIUtil::ReduceSum);
	for(int b=0; b<nBands; b++)
	{	FNLsum[b].allReduce(MPIUtil::ReduceSum);
		oNum[b].allReduce(MPIUtil::ReduceSum);
		oDen[b].allReduce(MPIUtil::ReduceSum);
	}
	//Convert to cumulative contributions:
	for(int b=1; b<nBands; b++)
	{	Fsum[b] += Fsum[b-1];
		FEsum[b] += FEsum[b-1];
		FNLsum[b] += FNLsum[b-1];
		oNum[b] += oNum[b-1];
		oDen[b] += oDen[b-1];
	}
	//Calculate actual numerator and denominator terms:
	const Basis& basisChi = es.basisChi[iq];
	const GridInfo& gInfo = *(basisChi.gInfo);
	double qWeight = es.qmesh[iq].weight;
	const vector3<>& q = es.qmesh[iq].k;
	int nbasis = basisChi.nbasis;
	double detRsq = std::pow(gInfo.detR, 2);
	diagMatrix K(nbasis), Kinv(nbasis), absKscrMinusK(nbasis);
	const double tol = 1e-8;
	for(int n=0; n<nbasis; n++)
	{	Kinv[n] = gInfo.GGT.metric_length_squared(q + basisChi.iGarr[n]) / (4*M_PI);
		K[n] = (fabs(Kinv[n])<tol) ? 0. : 1./Kinv[n];
		double invKscr = Kinv[n] - chiKS0[n];
		absKscrMinusK[n] = (fabs(invKscr)<tol) ? 0. : fabs(1./invKscr - K[n]);
	}
	diagMatrix wG = qWeight * absKscrMinusK * K;
	double wSum = trace(wG);
	double wKinvSum = dot(wG, Kinv);
	for(int b=0; b<nBands; b++)
	{	num[b] += wSum*FEsum[b] - detRsq*dot(wG,oNum[b]) + dot(wG,FNLsum[b]) + (2*M_PI)*wKinvSum*Fsum[b];
		den[b] += wSum*Fsum[b] - detRsq*dot(wG,oDen[b]);
	}
}