Example #1
File: THCGeneral.c  Project: noa/cutorch
void THCudaInit(THCState* state)
{
  int count = 0;
  THCudaCheck(cudaGetDeviceCount(&count));

  int device = 0;
  THCudaCheck(cudaGetDevice(&device));

  state->rngState = (THCRNGState*)malloc(sizeof(THCRNGState));
  THCRandom_init(state, count, device);

  state->blasState = (THCBlasState*)malloc(sizeof(THCBlasState));
  THCudaBlas_init(state, count, device);

  state->numDevices = count;
  state->deviceProperties =
    (struct cudaDeviceProp*)malloc(count * sizeof(struct cudaDeviceProp));

  THCState_setDeviceMode(state, THCStateDeviceModeManual);

  state->numUserStreams = 0;
  state->streamsPerDevice =
    (cudaStream_t**)malloc(count * sizeof(cudaStream_t*));

  /* Enable P2P access between all pairs, if possible */
  THCudaEnablePeerToPeerAccess(state);

  for (int i = 0; i < count; ++i)
  {
    THCudaCheck(cudaSetDevice(i));
    THCudaCheck(cudaGetDeviceProperties(&state->deviceProperties[i], i));
    /* Stream index 0 will be the default stream for convenience; by
       default no user streams are reserved */
    state->streamsPerDevice[i] =
      (cudaStream_t*)malloc(sizeof(cudaStream_t));
    state->streamsPerDevice[i][0] = NULL;
  }

  /* Restore to previous device */
  THCudaCheck(cudaSetDevice(device));

  /* Start in the default stream on the current device */
  state->currentPerDeviceStream = 0;
  state->currentStream = NULL;
}
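
As a point of reference, here is a minimal caller sketch for this API, assuming THCGeneral.h is on the include path and that THCudaShutdown is the matching teardown call; the usage below is illustrative and not part of the example above.

/* Hypothetical usage: allocate a zeroed THCState, initialize it, tear it down. */
#include <stdlib.h>
#include <string.h>
#include "THCGeneral.h"

int main(void)
{
  THCState* state = (THCState*)malloc(sizeof(THCState));
  memset(state, 0, sizeof(THCState));  /* zeroed so unset fields read as NULL/0 */

  THCudaInit(state);     /* probes devices and allocates per-device tables */

  /* ... run work using state->currentStream ... */

  THCudaShutdown(state); /* assumed counterpart that releases the resources */
  free(state);
  return 0;
}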
Example #2
void THCudaInit(THCState* state)
{
  int count = 0;
  THCudaCheck(cudaGetDeviceCount(&count));

  int device = 0;
  THCudaCheck(cudaGetDevice(&device));

  state->rngState = (THCRNGState*)malloc(sizeof(THCRNGState));
  THCRandom_init(state, count, device);

  THCAllocator_init(state);

  state->numDevices = count;
  state->deviceProperties =
    (struct cudaDeviceProp*)malloc(count * sizeof(struct cudaDeviceProp));

  state->numUserStreams = 0;
  state->numUserBlasHandles = 0;

  /* Enable P2P access between all pairs, if possible */
  THCudaEnablePeerToPeerAccess(state);

  state->resourcesPerDevice = (THCCudaResourcesPerDevice*)
    malloc(count * sizeof(THCCudaResourcesPerDevice));
  for (int i = 0; i < count; ++i) {
    THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, i);

    THCudaCheck(cudaSetDevice(i));
    THCudaCheck(cudaGetDeviceProperties(&state->deviceProperties[i], i));
    /* Stream index 0 will be the default stream for convenience; by
       default no user streams are reserved */
    res->streams = NULL;
    res->blasHandles = NULL;

    /* The scratch space that we want to have available per each device is
       based on the number of SMs available per device */
    int numSM = state->deviceProperties[i].multiProcessorCount;
    size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM;
    res->scratchSpacePerStream = sizePerStream;

    /* Allocate scratch space for each stream */
    res->devScratchSpacePerStream = (void**) malloc(sizeof(void*));
    THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0],
                             sizePerStream));
  }

  /* Restore to previous device */
  THCudaCheck(cudaSetDevice(device));

  /* Start in the default stream on the current device */
  state->currentPerDeviceStream = 0;
  state->currentStream = NULL;

  /* There is no such thing as a default cublas handle.
     To maintain consistency with streams API, handle 0 is always NULL and we
     start counting at 1
   */
  THCState_reserveBlasHandles(state, 1);
  state->currentPerDeviceBlasHandle = 1;
  state->currentBlasHandle = THCState_getDeviceBlasHandle(state, device, 1);

  state->cutorchGCFunction = NULL;
  state->cutorchGCData = NULL;
  state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
  state->heapDelta = 0;
}
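
The scratch-space sizing in this version is a straight product: bytes per stream = multiProcessorCount * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM. A minimal sketch of that arithmetic, assuming a macro value of 4 * sizeof(float); the real definition lives in THCGeneral.h and may differ between versions.

/* Sketch of the per-stream scratch sizing; the macro value is an assumption. */
#include <stdio.h>

#define GLOBAL_SCRATCH_SPACE_PER_SM_STREAM (4 * sizeof(float)) /* assumed */

int main(void)
{
  int numSM = 56;  /* e.g. multiProcessorCount reported for a Tesla P100 */
  size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM;
  printf("scratch bytes per stream: %zu\n", sizePerStream);  /* 56 * 16 = 896 */
  return 0;
}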
Example #3
void THCudaInit(THCState* state)
{
  if (!state->cudaDeviceAllocator.malloc) {
    THCState_initDefaultDeviceAllocator(&state->cudaDeviceAllocator);
  }

  int numDevices = 0;
  THCudaCheck(cudaGetDeviceCount(&numDevices));
  state->numDevices = numDevices;

  int device = 0;
  THCudaCheck(cudaGetDevice(&device));

  /* Start in the default stream on the current device */
  state->currentPerDeviceStream = THCThreadLocal_alloc();
  state->currentPerDeviceBlasHandle = THCThreadLocal_alloc();

  state->resourcesPerDevice = (THCCudaResourcesPerDevice*)
    malloc(numDevices * sizeof(THCCudaResourcesPerDevice));
  memset(state->resourcesPerDevice, 0, numDevices * sizeof(THCCudaResourcesPerDevice));

  state->deviceProperties =
    (struct cudaDeviceProp*)malloc(numDevices * sizeof(struct cudaDeviceProp));

  state->rngState = (THCRNGState*)malloc(sizeof(THCRNGState));
  THCRandom_init(state, numDevices, device);

  state->cudaHostAllocator = (THAllocator*)malloc(sizeof(THAllocator));
  THCAllocator_init(state->cudaHostAllocator);

  /* Enable P2P access between all pairs, if possible */
  THCudaEnablePeerToPeerAccess(state);

  for (int i = 0; i < numDevices; ++i) {
    THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, i);
    THCudaCheck(cudaSetDevice(i));
    THCudaCheck(cudaGetDeviceProperties(&state->deviceProperties[i], i));

    /* The scratch space that we want to have available per each device is
       based on the number of SMs available per device */
    int numSM = state->deviceProperties[i].multiProcessorCount;
    size_t sizePerStream = numSM * GLOBAL_SCRATCH_SPACE_PER_SM_STREAM;
    res->scratchSpacePerStream = sizePerStream;

    /* Allocate scratch space for each stream */
    res->devScratchSpacePerStream = (void**) malloc(sizeof(void*));
    THCudaCheck(THCudaMalloc(state, &res->devScratchSpacePerStream[0],
                             sizePerStream));
  }

  /* Restore to previous device */
  THCudaCheck(cudaSetDevice(device));

  /* There is no such thing as a default cublas handle.
     To maintain consistency with streams API, handle 0 is always NULL and we
     start counting at 1. If currentPerDeviceBlasHandle is 0 (the default
     thread-local value), then we assume it means 1.
   */
  THCState_reserveBlasHandles(state, 1);

  state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
  state->heapDelta = 0;
}
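
The comment above encodes a small convention worth spelling out: BLAS handle slot 0 is reserved (always NULL), handles are numbered from 1, and a thread-local value of 0, which is what an uninitialized thread reads, is interpreted as handle 1. A minimal sketch of that lookup rule, with illustrative names rather than the real cutorch accessors:

/* Hypothetical helper mirroring the numbering convention described above. */
#include <stdio.h>

static int effectiveBlasHandleIndex(int threadLocalValue)
{
  /* 0 means this thread never picked a handle; fall back to handle 1 */
  return threadLocalValue == 0 ? 1 : threadLocalValue;
}

int main(void)
{
  printf("%d\n", effectiveBlasHandleIndex(0));  /* 1: default maps to first handle */
  printf("%d\n", effectiveBlasHandleIndex(3));  /* 3: explicit choice is kept */
  return 0;
}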