/*------Destruction------*/
/*
 * Releases the process-wide BLAS resources held in this file's globals.
 *
 * Registered with the GCC `destructor` attribute, so it is invoked
 * automatically at program exit / library unload rather than being called
 * explicitly. Takes no arguments and returns nothing; outcomes are reported
 * only through Blasx_Debug_Output.
 *
 * Steps, in order:
 *   1. destroy the cublasXt handle, if one exists;
 *   2. dlclose() the CPU BLAS library handle, if one exists;
 *   3. if BLASX is enabled, tear down the SGEMM and DGEMM per-GPU resources.
 */
extern __attribute__ ((destructor)) void global_pointers_destruction(){
    printf("-----Destructing:\n");
    if(cublasXt_handle!=NULL)   {
        cublasStatus_t status=cublasXtDestroy(cublasXt_handle);
        if (status==CUBLAS_STATUS_SUCCESS) {
            Blasx_Debug_Output("the cublasXt library was successfully destroyed\n");
            /* The handle is no longer valid; do not leave it dangling. */
            cublasXt_handle=NULL;
        }else if(status==CUBLAS_STATUS_NOT_INITIALIZED){
            Blasx_Debug_Output("the cublasXt library was not initialized\n");
        }else{
            /* Any other status used to be dropped silently. */
            Blasx_Debug_Output("cublasXtDestroy failed\n");
        }
    }
    if (cpublas_handle!=NULL)   {
        /* dlclose() returns 0 on success and non-zero on failure. */
        int error = dlclose(cpublas_handle);
        if(error == 0) {
            Blasx_Debug_Output("cpublas_handle closed\n");
            cpublas_handle=NULL;
        }else{
            Blasx_Debug_Output("dlclose error\n");
        }
    }
    if (is_blasx_enable == 1) {
        Blasx_Debug_Output("dest blasx\n");
        /* One teardown call per precision: handles, streams, events and C buffers. */
        blasx_resource_dest( SYS_GPUS, handles_SGEMM, streams_SGEMM, event_SGEMM, C_dev_SGEMM);
        blasx_resource_dest( SYS_GPUS, handles_DGEMM, streams_DGEMM, event_DGEMM, C_dev_DGEMM);
    }
}
int main(int argc, char *argv[])
{
  MatMulArgs matMulArgs;
  matMulArgs.processArgs(argc, argv);

  size_t matrixAHeight = matMulArgs.getMatrixAHeight();
  size_t matrixBWidth = matMulArgs.getMatrixBWidth();
  size_t sharedDim = matMulArgs.getSharedDim();

  size_t blockSize = matMulArgs.getBlockSize();
  size_t numReadThreads = matMulArgs.getNumReadThreads();
  size_t numProdThreads = matMulArgs.getNumMatMulThreads();
  size_t numAccumThreads = (size_t) ceil((double)numProdThreads / 2.0);
  std::string directory = matMulArgs.getDirectory();
  std::string outputDirectory = matMulArgs.getOutputDir();
  bool runSequential = matMulArgs.isRunSequential();
  bool validate = matMulArgs.isValidateResults();

  size_t numGPUs = matMulArgs.getNumGPUs();
  int gpuIds[numGPUs];

  matMulArgs.copyGpuIds(gpuIds);


//  CUcontext *contexts = initCuda(numGPUs, gpuIds);

  std::string runtimeFileStr("runtimes");

  int numRetry = 1;

  std::ofstream runtimeFile(runtimeFileStr, std::ios::app);
  double *matrixA = new double[matrixAHeight * sharedDim];
  double *matrixB = new double[matrixBWidth * sharedDim];
  double *matrixC = new double[matrixAHeight * matrixBWidth];

  initMatrix(matrixA, sharedDim, matrixAHeight, true);
  initMatrix(matrixB, matrixBWidth, sharedDim, true);

  for (int numTry = 0; numTry < numRetry; numTry++) {
    SimpleClock clk;
    SimpleClock endToEnd;

    if (runSequential) {
      endToEnd.start();
      initMatMul(numProdThreads);

      cublasXtHandle_t handle;

      cublasXtCreate(&handle);

      cublasXtDeviceSelect(handle, numGPUs, gpuIds);
      cublasXtSetBlockDim(handle, blockSize);

      clk.start();
      computeSequentialMatMul(matrixA, matrixB, matrixC, (size_t) matrixAHeight, (size_t) sharedDim,
                              (size_t) matrixBWidth, handle);
      clk.stopAndIncrement();

      cublasXtDestroy(handle);

      endToEnd.stopAndIncrement();
    }
    else {
      endToEnd.start();
      initMatMul(1);

      LoadMatrixTask *readAMatTask =
          new LoadMatrixTask(matrixA,
                             numReadThreads,
                             MatrixType::MatrixA,
                             blockSize,
                             sharedDim,
                             matrixAHeight,
                             true);

      LoadMatrixTask *readBMatTask =
          new LoadMatrixTask(matrixB,
                             numReadThreads,
                             MatrixType::MatrixB,
                             blockSize,
                             matrixBWidth,
                             sharedDim,
                             true);

      MatrixMulBlkCudaTask *mmulTask = new MatrixMulBlkCudaTask(gpuIds, numGPUs);
      MatMulAccumTask *accumTask = new MatMulAccumTask(numAccumThreads, true);

      MatMulOutputTask *outputTask = new MatMulOutputTask(matrixC, matrixAHeight, blockSize, true);

      size_t blkHeightMatB = readBMatTask->getNumBlocksRows();
      size_t blkWidthMatB = readBMatTask->getNumBlocksCols();

      size_t blkHeightMatA = readAMatTask->getNumBlocksRows();
      size_t blkWidthMatA = readAMatTask->getNumBlocksCols();

      CudaCopyInTask *cudaCopyInATask = new CudaCopyInTask(gpuIds, numGPUs, MatrixType::MatrixA, blkWidthMatB);
      CudaCopyInTask *cudaCopyInBTask = new CudaCopyInTask(gpuIds, numGPUs, MatrixType::MatrixB, blkHeightMatA);

      CudaCopyOutTask *cudaCopyOutCTask = new CudaCopyOutTask(gpuIds, numGPUs, MatrixType::MatrixC);

      MatMulDistributeRule *distributeRuleMatA = new MatMulDistributeRule(MatrixType::MatrixA);
      MatMulDistributeRule *distributeRuleMatB = new MatMulDistributeRule(MatrixType::MatrixB);

      MatMulLoadRule<htgs::m_data_t<double>> *loadRule = new MatMulLoadRule<htgs::m_data_t<double>>(blkWidthMatA, blkHeightMatA, blkWidthMatB, blkHeightMatB);
      MatMulAccumulateRule<double *> *accumulateRule = new MatMulAccumulateRule<double *>(blkWidthMatB, blkHeightMatA, blkWidthMatA);

      MatMulOutputRule *outputRule = new MatMulOutputRule(blkWidthMatB, blkHeightMatA, blkWidthMatA);

      auto distributeBk = new htgs::Bookkeeper<MatrixRequestData>();
      auto matMulBk = new htgs::Bookkeeper<MatrixBlockData<htgs::m_data_t<double>>>();
      auto matAccumBk = new htgs::Bookkeeper<MatrixBlockData<double *>>();

      auto taskGraph = new htgs::TaskGraphConf<MatrixRequestData, MatrixBlockData<double *>>();

      taskGraph->setGraphConsumerTask(distributeBk);
      taskGraph->addRuleEdge(distributeBk, distributeRuleMatA, readAMatTask);
      taskGraph->addRuleEdge(distributeBk, distributeRuleMatB, readBMatTask);


      taskGraph->addEdge(readAMatTask, cudaCopyInATask);
      taskGraph->addEdge(readBMatTask, cudaCopyInBTask);

      taskGraph->addEdge(cudaCopyInATask, matMulBk);
      taskGraph->addEdge(cudaCopyInBTask, matMulBk);

      taskGraph->addRuleEdge(matMulBk, loadRule, mmulTask);

      taskGraph->addEdge(mmulTask, cudaCopyOutCTask);

      taskGraph->addGraphProducerTask(cudaCopyOutCTask);

      taskGraph->addCudaMemoryManagerEdge(matrixTypeToString(MatrixType::MatrixA) + "Copy",
                                          cudaCopyInATask,
                                          new CudaAllocator(blockSize, blockSize),
                                          blkWidthMatB+1,
                                          htgs::MMType::Static,
                                          gpuIds);
      taskGraph->addCudaMemoryManagerEdge(matrixTypeToString(MatrixType::MatrixB) + "Copy",
                                          cudaCopyInBTask,
                                          new CudaAllocator(blockSize, blockSize),
                                          blkHeightMatA+1,
                                          htgs::MMType::Static,
                                          gpuIds);

      taskGraph->addCudaMemoryManagerEdge(matrixTypeToString(MatrixType::MatrixC),
                                          mmulTask,
                                          new CudaAllocator(blockSize, blockSize),
                                          4,
                                          htgs::MMType::Static,
                                          gpuIds);


      auto mainTaskGraph = new htgs::TaskGraphConf<MatrixRequestData, MatrixRequestData>();


      auto execPipeline = new htgs::ExecutionPipeline<MatrixRequestData, MatrixBlockData<double *>>(numGPUs, taskGraph);
      auto decompositionRule = new MatrixDecompositionRule(numGPUs);

      execPipeline->addInputRule(decompositionRule);

      mainTaskGraph->setGraphConsumerTask(execPipeline);
      mainTaskGraph->addEdge(execPipeline, matAccumBk);


      mainTaskGraph->addRuleEdge(matAccumBk, outputRule, outputTask);
      mainTaskGraph->addRuleEdge(matAccumBk, accumulateRule, accumTask);

      mainTaskGraph->addEdge(accumTask, matAccumBk);

      mainTaskGraph->addGraphProducerTask(outputTask);

//      mainTaskGraph->writeDotToFile("pre-execution.dot");

      htgs::TaskGraphRuntime *runtime = new htgs::TaskGraphRuntime(mainTaskGraph);

      clk.start();

      runtime->executeRuntime();

      for (size_t col = 0; col < blkWidthMatA; col++) {
        for (size_t row = 0; row < blkHeightMatA; row++) {

          MatrixRequestData *matA = new MatrixRequestData(row, col, MatrixType::MatrixA);
          mainTaskGraph->produceData(matA);
        }
      }

      for (size_t row = 0; row < blkHeightMatB; row++) {
        for (size_t col = 0; col < blkWidthMatB; col++) {
          MatrixRequestData *matB = new MatrixRequestData(row, col, MatrixType::MatrixB);
          mainTaskGraph->produceData(matB);

        }
      }

      mainTaskGraph->finishedProducingData();

      while (!mainTaskGraph->isOutputTerminated()) {
        auto data = mainTaskGraph->consumeData();
        if (data != nullptr) {
//          std::cout << data->getRow() << ", " << data->getCol() << std::endl;
        }
      }

      runtime->waitForRuntime();


//      taskGraph->writeDotToFile("profile-graph.dot");
//      mainTaskGraph->writeDotToFile("profile-all-threads-graph.dot", DOTGEN_FLAG_SHOW_ALL_THREADING);
      mainTaskGraph->writeDotToFile("matrix-multiplication-cuda-multigpu.dot", DOTGEN_COLOR_COMP_TIME);

      clk.stopAndIncrement();

      delete runtime;
      endToEnd.stopAndIncrement();
    }

    if (validate) {
      double *matrixCTest = new double[matrixAHeight * matrixBWidth];
      initMatMul(numProdThreads);

      cublasXtHandle_t handle;

      cublasXtCreate(&handle);

      cublasXtDeviceSelect(handle, (int)numGPUs, gpuIds);
      cublasXtSetBlockDim(handle, (int)blockSize);

      computeSequentialMatMul(matrixA, matrixB, matrixCTest, (size_t) matrixAHeight, (size_t) sharedDim,
                              (size_t) matrixBWidth, handle);

      cublasXtDestroy(handle);


      int res = validateResults(matrixC, matrixCTest, matrixAHeight, matrixBWidth);
      if (res != 0) {
        std::cout << "Error validating test failed!" << std::endl;
      }
      else
      {
        std::cout << "Test PASSED" << std::endl;
      }

      delete []matrixCTest;

    }


    double numGflops = (2.0 * matrixAHeight *sharedDim * matrixBWidth) * 1.0e-9d;
    double gflops = numGflops / clk.getAverageTime(TimeVal::SEC);



    std::cout << (runSequential ? "sequential" : "htgs") << ", " << numProdThreads
              << ", accum-threads: " << numAccumThreads << ", width-b: " << matrixBWidth << ", height-a: " << matrixAHeight
              << ", shared-dim: " << sharedDim
              << ", blockSize: " << blockSize
              << ", time:" << clk.getAverageTime(TimeVal::MILLI)
              << ", end-to-end:" << endToEnd.getAverageTime(TimeVal::MILLI)
              << ", gflops: " << gflops
              << std::endl;

    runtimeFile << "MULTIGPU-MM" << (runSequential ? "sequential" : "htgs") << ", " << numProdThreads
                << ", " << numAccumThreads << ", "
                << matrixBWidth << ", " << matrixAHeight
                << ", " << sharedDim << ", " << blockSize << ", " << clk.getAverageTime(TimeVal::MILLI)
                << ", " << endToEnd.getAverageTime(TimeVal::MILLI)
                << std::endl;

  }

  delete[] matrixA;
  delete[] matrixB;
  delete[] matrixC;
}