/*
 * Request the CUPTI counter values and write them to the given VampirTrace
 * stream with the given timestamps.
 *
 * @param vtcuptiEvtCtx pointer to the VampirTrace CUPTI events context
 *                      (may be NULL: the current thread's context is used)
 * @param strmid the stream id for the counter values
 * @param time the VampirTrace timestamps
 */
void vt_cuptievt_writeCounter(vt_cupti_events_t *vtcuptiEvtCtx, uint32_t strmid,
                              uint64_t *time)
{
  CUptiResult cuptiErr = CUPTI_SUCCESS;
  vt_cupti_evtgrp_t *vtcuptiGrp = NULL;

  size_t bufferSizeBytes;
  size_t arraySizeBytes;
  size_t numCountersRead;

  /* fall back to the calling thread's events context if none was given */
  if(vtcuptiEvtCtx == NULL){
    VT_CHECK_THREAD;
    vtcuptiEvtCtx = vt_cuptievt_getOrCreateCurrentCtx(VT_MY_THREAD)->events;
    if(vtcuptiEvtCtx == NULL) return;
  }

  vtcuptiGrp = vtcuptiEvtCtx->vtGrpList;
  while(vtcuptiGrp != NULL){
    /* read events only, if the event group is enabled */
    if(vtcuptiGrp->enabled){

      bufferSizeBytes = vtcuptiGrp->evtNum * sizeof(uint64_t);
      arraySizeBytes = vtcuptiGrp->evtNum * sizeof(CUpti_EventID);

      /* read events: counterData[k] belongs to cuptiEvtIDs[k] (both buffers
       * are filled in parallel by this single call) */
      cuptiErr = cuptiEventGroupReadAllEvents(vtcuptiGrp->evtGrp,
                                              CUPTI_EVENT_READ_FLAG_NONE,
                                              &bufferSizeBytes, vtcuptiEvtCtx->counterData,
                                              &arraySizeBytes, vtcuptiEvtCtx->cuptiEvtIDs,
                                              &numCountersRead);
      VT_CUPTI_CALL(cuptiErr, "cuptiEventGroupReadAllEvents");

      if(vtcuptiGrp->evtNum != numCountersRead){
        /* cast explicitly: numCountersRead is a size_t, which does not match
         * the %d conversion (undefined behavior on LP64 platforms) */
        vt_error_msg("[CUPTI Events] %d counter reads, %d metrics specified in "
                     "VT_CUPTI_METRICS!", (int)numCountersRead,
                     (int)vtcuptiGrp->evtNum);
      }

      /* For all events of the event group: map added event IDs to just read event
       * IDs, as the order may not be the same. For small numbers of counter reads
       * this simple mapping should be fast enough.
       */
      {
        size_t j;

        for(j = 0; j < numCountersRead; j++){
          size_t i;
          for(i = 0; i < vtcuptiGrp->evtNum; i++){
            if(vtcuptiEvtCtx->cuptiEvtIDs[j] == *(vtcuptiGrp->cuptiEvtIDs+i)){
              /* write the counter value as VampirTrace counter.
               * NOTE: the value parallel to the just-read ID [j] is
               * counterData[j]; indexing with i would pair values with the
               * wrong counters whenever the read order differs from the
               * registration order -- which is exactly why this mapping
               * loop exists. */
              vt_count(strmid, time, *(vtcuptiGrp->vtCIDs+i), vtcuptiEvtCtx->counterData[j]);
            }
          }
        }
      }

    }

    vtcuptiGrp = vtcuptiGrp->next;
  }

}
/*
 * CUPTI callback invoked at kernel-launch API enter and exit.
 *
 * On CUPTI_API_ENTER: sets kernel collection mode and enables every event
 * group of the first group set of the launching context.
 * On CUPTI_API_EXIT: reads all counters back, normalises the per-instance
 * values to the device's total instance count, and accumulates the
 * configured metrics into the file-level `metrics` array.
 *
 * @param user   per-thread user data; also used as a recursion guard when the
 *               context's counter data has to be created lazily (see below)
 * @param cbdata CUPTI callback data carrying the context and callback site
 */
static void cupti_callback_launch_kernel(cupti_user_t *user, CUpti_CallbackData *cbdata)
{

    // Find associated counter data
    pthread_mutex_lock(&mutex);
    struct context_counter_data *counter_data;
    for (counter_data = allCounterData; counter_data; counter_data = counter_data->next) {
        if (counter_data->context == cbdata->context) {
            break;
        }
    }
    // NOTE(review): counter_data is dereferenced after the unlock below;
    // this is only safe if list nodes are never freed/modified concurrently
    // -- confirm against the code that maintains allCounterData.
    pthread_mutex_unlock(&mutex);
    if (!counter_data) {
        if (cbdata->callbackSite == CUPTI_API_ENTER) {
            fprintf(stderr, "CUPTI warning: Could not find context for kernel start!\n");
            // Simply generate it. Use user data as a cheap way to
            // prevent an infinite loop.
            if (user) {
                // Build a minimal resource-data record so the context-created
                // handler can register this context, then retry once with
                // user == NULL so a second failure cannot recurse again.
                CUpti_ResourceData rdata;
                memset(&rdata, 0, sizeof(rdata));
                rdata.context = cbdata->context;
                cupti_callback_context_created(user, &rdata);
                cupti_callback_launch_kernel(NULL, cbdata);
            }
        }
        return;
    }
    // Only the first group set (first pass) is used here; metrics needing
    // multiple passes are presumably handled elsewhere -- TODO confirm.
    CUpti_EventGroupSets *eventGroupPasses = counter_data->eventGroupSets;

    if (cbdata->callbackSite == CUPTI_API_ENTER) {
        //cudaDeviceSynchronize();

        // Set collection mode. Kernel mode is the only one that is
        // guaranteed to work, even if it forces us to sum up metrics
        // manually.
        CUPTI_ASSERT(cuptiSetEventCollectionMode(cbdata->context,
                                                 CUPTI_EVENT_COLLECTION_MODE_KERNEL));

        // Enable the counters!
        int i;
        for (i = 0; i < eventGroupPasses->sets->numEventGroups; i++) {
            // Count on every domain instance, not just the default one,
            // so the EXIT path can sum across instances.
            uint32_t all = 1;
            CUPTI_ASSERT(cuptiEventGroupSetAttribute(eventGroupPasses->sets->eventGroups[i],
                                                     CUPTI_EVENT_GROUP_ATTR_PROFILE_ALL_DOMAIN_INSTANCES,
                                                     sizeof(all), &all));
            CUPTI_ASSERT(cuptiEventGroupEnable(eventGroupPasses->sets->eventGroups[i]));
        }
    }

    else if (cbdata->callbackSite == CUPTI_API_EXIT) {
        CUdevice device = get_device_from_ctx(cbdata->context);

        // Find out how many events we have in total. Note that
        // cuptiMetricGetNumEvents wouldn't help us here, as we are
        // collecting multiple metrics, which *might* have overlapping
        // events.
        uint32_t numEvents = 0; int i;
        for (i = 0; i < eventGroupPasses->sets->numEventGroups; i++) {
            uint32_t num = 0;
            size_t numSize = sizeof(num);
            CUPTI_ASSERT(cuptiEventGroupGetAttribute(eventGroupPasses->sets->eventGroups[i],
                                                     CUPTI_EVENT_GROUP_ATTR_NUM_EVENTS,
                                                     &numSize, &num));
            numEvents += num;
        }

        // Allocate arrays for event IDs & values.
        // NOTE(review): alloca is stack allocation; fine for small event
        // counts but unchecked -- a very large event set would overflow.
        size_t eventIdsSize = sizeof(CUpti_EventID) * numEvents;
        CUpti_EventID *eventIds = (CUpti_EventID *)alloca(eventIdsSize);
        size_t eventValuesSize = sizeof(uint64_t) * numEvents;
        uint64_t *eventValues = (uint64_t *)alloca(eventValuesSize);
        memset(eventValues, 0, sizeof(uint64_t) * numEvents);

        // Now read all events, per group
        int eventIx = 0;
        for (i = 0; i < eventGroupPasses->sets->numEventGroups; i++) {
            CUpti_EventGroup eventGroup = eventGroupPasses->sets->eventGroups[i];

            // Get event IDs
            uint32_t num = 0;
            size_t numSize = sizeof(num);
            CUPTI_ASSERT(cuptiEventGroupGetAttribute(eventGroup,
                                                     CUPTI_EVENT_GROUP_ATTR_NUM_EVENTS,
                                                     &numSize, &num));

            // Get how many domain instances were actually counting
            uint32_t domInstNum = 0;
            size_t domInstNumSize = sizeof(domInstNum);
            CUPTI_ASSERT(cuptiEventGroupGetAttribute(eventGroup,
                                                     CUPTI_EVENT_GROUP_ATTR_INSTANCE_COUNT,
                                                     &domInstNumSize, &domInstNum));

            // Get counter values from all instances. Event IDs are written
            // directly into the shared eventIds array at this group's offset.
            size_t idsSize = sizeof(CUpti_EventID) * num;
            size_t valsSize = sizeof(uint64_t) * num * domInstNum;
            uint64_t *vals = (uint64_t *)alloca(valsSize);
            size_t numRead = 0;
            CUPTI_ASSERT(cuptiEventGroupReadAllEvents(eventGroup,
                                                      CUPTI_EVENT_READ_FLAG_NONE,
                                                      &valsSize,
                                                      vals,
                                                      &idsSize,
                                                      eventIds + eventIx,
                                                      &numRead));
            if (numRead != num) {
                fprintf(stderr, "CUPTI warning: ReadAllEvents returned unexpected number of values (expected %u, got %u)!\n", (unsigned)num, (unsigned)numRead);
            }

            // For normalisation we need the *total* number of domain
            // instances (not only the ones that were available for counting)
            CUpti_EventDomainID domainId = 0;
            size_t domainIdSize = sizeof(domainId);
            CUPTI_ASSERT(cuptiEventGroupGetAttribute(eventGroup,
                                                     CUPTI_EVENT_GROUP_ATTR_EVENT_DOMAIN_ID,
                                                     &domainIdSize, &domainId));
            uint32_t totalDomInstNum = 0;
            size_t totalDomInstNumSize = sizeof(totalDomInstNum);
            CUPTI_ASSERT(cuptiDeviceGetEventDomainAttribute(device, domainId,
                                                            CUPTI_EVENT_DOMAIN_ATTR_TOTAL_INSTANCE_COUNT,
                                                            &totalDomInstNumSize,
                                                            &totalDomInstNum));

            // Determine true counter values
            int j;
            for (j = 0; j < numRead; j++) {

                // First, sum up across instances.
                // NOTE(review): this indexing assumes an instance-major
                // buffer layout (all events of instance k contiguous) --
                // verify against the cuptiEventGroupReadAllEvents
                // documentation for the CUPTI version in use.
                uint64_t val = 0; int k;
                for (k = 0; k < domInstNum; k++) {
                    val += vals[j+k*num];
                }

                // Then normalise and add to proper event count
                // (scale counted instances up to the device total).
                eventValues[eventIx + j] = (val * totalDomInstNum) / domInstNum;
            }

            // Progress!
            // NOTE(review): advances by num even if numRead != num; a short
            // read leaves stale/zero slots behind -- intended as best-effort?
            eventIx += num;
        }

        // Now calculate metrics.
        for (i = 0; i < metricCount; i++) {

            // This only works if the metric does not depend on kernel
            // time (because we set it to zero here - use
            // cupti_activity facilities to measure kernel time
            // separately).
            CUpti_MetricValue metric;
            CUPTI_ASSERT(cuptiMetricGetValue(device, counter_data->metricIds[i],
                                             eventIdsSize, eventIds,
                                             eventValuesSize, eventValues,
                                             0, &metric));

            // Sum up metrics. Note that this might not actually make
            // sense for all of them, we warn about that before.
            switch (counter_data->metricKinds[i]) {
            case CUPTI_METRIC_VALUE_KIND_DOUBLE:
                metrics[i] += metric.metricValueDouble;
                break;
            case CUPTI_METRIC_VALUE_KIND_UINT64:
                metrics[i] += metric.metricValueUint64;
                break;
            case CUPTI_METRIC_VALUE_KIND_INT64:
                metrics[i] += metric.metricValueInt64;
                break;
            case CUPTI_METRIC_VALUE_KIND_PERCENT:
                metrics[i] += metric.metricValuePercent;
                break;
            case CUPTI_METRIC_VALUE_KIND_THROUGHPUT:
                metrics[i] += metric.metricValueThroughput;
                break;
            case CUPTI_METRIC_VALUE_KIND_UTILIZATION_LEVEL:
                metrics[i] += metric.metricValueUtilizationLevel;
                break;
            }
        }
    }
}