//--------------------------------------------------------------------------
// Multiclass classification accuracy, in percent.
//
// ytest_truth   : length-n vector of ground-truth labels, stored as doubles
//                 and compared after truncation to int.
// Ytest_predict : n x NumClasses matrix of per-class scores. The predicted
//                 label of test point i is the index returned by
//                 DVector::Max on row i (the column of the largest score).
// NumClasses    : number of classes; must be greater than 2.
//
// Returns the percentage (0..100) of test points whose predicted label equals
// the ground truth. Returns NAN (after printing a message) if the sizes are
// inconsistent, the problem is not multiclass, or there are no test points.
double Performance(const DVector &ytest_truth, const DMatrix &Ytest_predict,
                   int NumClasses) {

  long n = ytest_truth.GetN();
  if (n != Ytest_predict.GetM() || Ytest_predict.GetN() != NumClasses) {
    printf("Performance. Error: Size mismatch. Return NAN\n");
    return NAN;
  }
  if (NumClasses <= 2) {
    printf("Performance. Error: Not multiclass classification. Return NAN\n");
    return NAN;
  }
  if (n == 0) {
    // Accuracy over an empty test set is undefined; say so explicitly
    // instead of silently producing 0/0.
    printf("Performance. Error: No test points. Return NAN\n");
    return NAN;
  }

  const double *y_truth = ytest_truth.GetPointer();
  DVector row(NumClasses); // scratch buffer reused for every test point
  long num_correct = 0;
  for (long i = 0; i < n; i++) {
    Ytest_predict.GetRow(i, row);
    long idx = -1;
    row.Max(idx);
    // Labels are doubles; compare them as integers.
    if ((int)y_truth[i] == (int)idx) {
      num_correct++;
    }
  }

  return (double)num_correct / n * 100.0;
}
//--------------------------------------------------------------------------
int main(int argc, char **argv) {

  // Temporary variables
  long int i, j, k, ii, idx = 1;
  double ElapsedTime;

  // Arguments
  int NumThreads = atoi(argv[idx++]);
  char *FileTrain = argv[idx++];
  char *FileTest = argv[idx++];
  int NumClasses = atoi(argv[idx++]);
  int d = atoi(argv[idx++]);
  int r = atoi(argv[idx++]);
  int Num_lambda = atoi(argv[idx++]);
  double *List_lambda = (double *)malloc(Num_lambda*sizeof(double));
  for (ii = 0; ii < Num_lambda; ii++) {
    List_lambda[ii] = atof(argv[idx++]);
  }
  int Num_sigma = atoi(argv[idx++]);
  double *List_sigma = (double *)malloc(Num_sigma*sizeof(double));
  for (i = 0; i < Num_sigma; i++) {
    List_sigma[i] = atof(argv[idx++]);
  }
  int MAXIT = atoi(argv[idx++]);
  double TOL = atof(argv[idx++]);
  bool verbose = atoi(argv[idx++]);

  // Threading
#ifdef USE_OPENBLAS
  openblas_set_num_threads(NumThreads);
#elif USE_ESSL

#elif USE_OPENMP
  omp_set_num_threads(NumThreads);
#else
  NumThreads = 1; // To avoid compiler warining of unused variable
#endif

  PREPARE_CLOCK(1);
  START_CLOCK;

  // Read in X = Xtrain (n*d), y = ytrain (n*1),
  //     and X0 = Xtest (m*d), y0 = ytest (m*1)
  DPointArray Xtrain;        // read all data points from train
  DPointArray Xtest;        // read all data points from test
  DVector ytrain;           // Training labels
  DVector ytest;            // Testing labels (ground truth)
  DVector ytest_predict;    // Predictions

  if (ReadData(FileTrain, Xtrain, ytrain, d) == 0) {
    return -1;
  }
  if (ReadData(FileTest, Xtest, ytest, d) == 0) {
    return -1;
  }

  END_CLOCK;
  ElapsedTime = ELAPSED_TIME;
  printf("OneVsAll: time loading data = %g seconds\n", ElapsedTime); fflush(stdout);

  // For multiclass classification, need to convert a single vector
  // ytrain to a matrix Ytrain. The "predictions" are stored in the
  // corresponding matrix Ytest_predict. The vector ytest_predict is
  DMatrix Ytrain;
  ConvertYtrain(ytrain, Ytrain, NumClasses);

  int Seed = 0; // initialize seed as zero
  // Loop over List_lambda
  for (ii = 0; ii < Num_lambda; ii++) {
    double lambda = List_lambda[ii];
    // Loop over List_sigma
    for (k = 0; k < Num_sigma; k++) {
    double sigma = List_sigma[k];
    // Seed the RNG
    srandom(Seed);

    START_CLOCK;
    // Generate feature matrix Xdata_randbin given Xdata
    vector< vector< pair<int,double> > > instances_old, instances_new;
    long Xtrain_N = Xtrain.GetN();
    for(i=0;i<Xtrain_N;i++){
      instances_old.push_back(vector<pair<int,double> >());
      for(j=0;j<d;j++){
        int index = j+1;
        double *myXtrain = Xtrain.GetPointer();
        double  myXtrain_feature = myXtrain[j*Xtrain_N+i];
        if (myXtrain_feature != 0)
          instances_old.back().push_back(pair<int,double>(index, myXtrain_feature));
      }
    }
    long Xtest_N = Xtest.GetN();
    for(i=0;i<Xtest_N;i++){
      instances_old.push_back(vector<pair<int,double> >());
      for(j=0;j<d;j++){
        int index = j+1;
        double *myXtest = Xtest.GetPointer();
        double  myXtest_feature = myXtest[j*Xtest_N+i];
        if (myXtest_feature != 0)
          instances_old.back().push_back(pair<int,double>(index, myXtest_feature));
      }
    }
    END_CLOCK;
    printf("Train. RandBin: Time (in seconds) for converting data format: %g\n", ELAPSED_TIME);fflush(stdout);
     
     // add 0 feature for Enxu's code
    START_CLOCK;
    random_binning_feature(d+1, r, instances_old, instances_new, sigma);
    END_CLOCK;
    printf("Train. RandBin: Time (in seconds) for generating random binning features: %g\n", ELAPSED_TIME);fflush(stdout);

    START_CLOCK;
    SPointArray Xdata_randbin;  // Generate random binning features
    long int nnz = r*(Xtrain_N + Xtest_N);
    long int dd = 0;
    for(i = 0; i < instances_new.size(); i++){
      if(dd < instances_new[i][r-1].first)
        dd = instances_new[i][r-1].first;
    }
    Xdata_randbin.Init(Xtrain_N+Xtest_N, dd, nnz);
    long int ind = 0;
    long int *mystart = Xdata_randbin.GetPointerStart();
    int *myidx = Xdata_randbin.GetPointerIdx();
    double *myX = Xdata_randbin.GetPointerX();
    for(i = 0; i < instances_new.size(); i++){
      if (i == 0)
        mystart[i] = 0;
      else
        mystart[i] = mystart[i-1] + r;
      for(j = 0; j < instances_new[i].size(); j++){
        myidx[ind] = instances_new[i][j].first-1;
        myX[ind] = instances_new[i][j].second;
        ind++;
      }
    }
    mystart[i] = nnz; // mystart has a length N+1
    // generate random binning features for Xtrain and Xtest
    SPointArray Xtrain;         // Training points
    SPointArray Xtest;          // Testing points
    long Row_start = 0;
    Xdata_randbin.GetSubset(Row_start, Xtrain_N,Xtrain);
    Xdata_randbin.GetSubset(Xtrain_N,Xtest_N,Xtest);
    Xdata_randbin.ReleaseAllMemory();
    END_CLOCK;
    printf("Train. RandBin: Time (in seconds) for converting data format back: %g\n", ELAPSED_TIME);fflush(stdout);
    printf("OneVsAll: n train = %ld, m test = %ld, r = %d, D = %ld, Gamma = %f, num threads = %d\n", Xtrain_N, Xtest_N, r, dd, sigma, NumThreads); fflush(stdout);

    // solve (Z'Z + lambdaI)w = Z'y, note that we never explicitly form
    // Z'Z since Z is a large sparse matrix N*dd
    START_CLOCK;
    int m = Ytrain.GetN(); // number of classes
    long N = Xtrain.GetN(); // number of training points
    long NN = Xtest.GetN(); // number of training points
    long M = Xtrain.GetD(); // dimension of randome binning features
    DMatrix Ytest_predict(NN,m);
    DMatrix W(M,m);
    SPointArray EYE;
    EYE.Init(M,M,M);
    mystart = EYE.GetPointerStart();
    myidx = EYE.GetPointerIdx();
    myX = EYE.GetPointerX();
    for(i=0;i<M;i++){
      mystart[i] = i;
      myidx[i] = i;
      myX[i] = 1;
    }
    mystart[i] = M+1; // mystart has a length N+1
    for (i = 0; i < m; i++) {
      DVector w;
      w.Init(M);
      DVector ytrain, yy;
      Ytrain.GetColumn(i, ytrain);
      Xtrain.MatVec(ytrain, yy, TRANSPOSE);
      double NormRHS = yy.Norm2();
      PCG pcg_solver;
      pcg_solver.Solve<SPointArray, SPointArray>(Xtrain, yy, w, EYE, MAXIT, TOL, 1);
      if (verbose) {
        int Iter = 0;
        const double *ResHistory = pcg_solver.GetResHistory(Iter);
        printf("RLCM::Train, PCG. iteration = %d, Relative residual = %g\n",
          Iter, ResHistory[Iter-1]/NormRHS);fflush(stdout);
      }

      pcg_solver.GetSolution(w);
      W.SetColumn(i, w);
    }
    END_CLOCK;
    printf("Train. RandBin: Time (in seconds) for solving linear system solution: %g\n", ELAPSED_TIME);fflush(stdout);

    // y = Xtest*W = z(x)'*w
    START_CLOCK;
    Xtest.MatMat(W,Ytest_predict,NORMAL,NORMAL);
    double accuracy = Performance(ytest, Ytest_predict, NumClasses);
    END_CLOCK;
    ElapsedTime = ELAPSED_TIME;
    printf("Test. RandBin: param = %g %g, perf = %g, time = %g\n", sigma, lambda, accuracy, ElapsedTime); fflush(stdout);

  }// End loop over List_sigma
  }// End loop over List_lambda

  // Clean up
  free(List_sigma);
  free(List_lambda);

  return 0;
}