/*
 * Looks up the scratch-space pointer for the currently selected CUDA device
 * and the current stream index recorded in `state`.
 *
 * The device ordinal is obtained from cudaGetDevice and the stream index from
 * THCState_getCurrentStreamIndex; both are forwarded unchanged to
 * THCState_getDeviceScratchSpace, whose result is returned as-is.
 */
void* THCState_getCurrentDeviceScratchSpace(THCState* state)
{
  int currentDevice = -1;
  THCudaCheck(cudaGetDevice(&currentDevice));

  int streamIndex = THCState_getCurrentStreamIndex(state);

  void* scratch =
    THCState_getDeviceScratchSpace(state, currentDevice, streamIndex);
  return scratch;
}
/*
 * Grows the per-device stream tables so that `numStreams` user streams are
 * available on every device in [0, state->numDevices).
 *
 * Parameters:
 *   state       - THC state whose per-device resources are extended.
 *   numStreams  - requested number of user streams; the call is a no-op when
 *                 this does not exceed state->numUserStreams.
 *   nonBlocking - non-zero creates the new streams with cudaStreamNonBlocking,
 *                 zero with cudaStreamDefault.
 *
 * For each device the function allocates new host arrays of
 * (numStreams + 1) entries (slot 0 is the default stream), copies the
 * existing stream handles and scratch-space pointers, creates the additional
 * streams together with one scratch buffer each (sized by
 * THCState_getDeviceScratchSpaceSize), and then swaps the new arrays into the
 * device's resource record, freeing the old arrays. The device that was
 * current on entry is restored before returning.
 *
 * Errors are reported through THCudaCheck. A failed host allocation is
 * reported as cudaErrorMemoryAllocation instead of being dereferenced.
 * NOTE(review): if an error is raised part-way through the device loop,
 * devices already processed keep their enlarged tables while
 * state->numUserStreams is left at its old value.
 */
void THCState_reserveStreams(THCState* state, int numStreams, int nonBlocking)
{
  if (numStreams <= state->numUserStreams)
  {
    return;
  }

  int prevDev = -1;
  THCudaCheck(cudaGetDevice(&prevDev));

  /* +1 for the default stream as well; computed in size_t so the addition
     and the byte-size multiplication are not done in int */
  const size_t numSlots = (size_t) numStreams + 1;

  /* Otherwise, we have to allocate a new set of streams and stream data */
  for (int dev = 0; dev < state->numDevices; ++dev) {
    THCudaCheck(cudaSetDevice(dev));

    cudaStream_t* newStreams =
      (cudaStream_t*) malloc(numSlots * sizeof(cudaStream_t));
    void** newScratchSpace =
      (void**) malloc(numSlots * sizeof(void*));

    if (newStreams == NULL || newScratchSpace == NULL) {
      /* free(NULL) is a no-op, so release whichever allocation succeeded */
      free(newStreams);
      free(newScratchSpace);
      THCudaCheck(cudaSetDevice(prevDev));
      /* Report host OOM through the same channel as every other failure here.
         THCudaCheck is expected to raise on a non-success code; the return
         below only guards against a NULL dereference should it not. */
      THCudaCheck(cudaErrorMemoryAllocation);
      return;
    }

    /* Copy over old stream data
       (0 is default stream, 1 ... numUserStreams are rest) */
    for (int stream = 0; stream <= state->numUserStreams; ++stream) {
      newStreams[stream] =
        THCState_getDeviceStream(state, dev, stream);
      newScratchSpace[stream] =
        THCState_getDeviceScratchSpace(state, dev, stream);
    }

    /* Allocate new stream resources */
    size_t scratchSpaceSize = THCState_getDeviceScratchSpaceSize(state, dev);
    unsigned int flags =
      nonBlocking ? cudaStreamNonBlocking : cudaStreamDefault;

    for (int stream = state->numUserStreams + 1; stream <= numStreams; ++stream) {
      /* Pre-clear each slot so it never holds an indeterminate value if the
         creation call fails */
      newStreams[stream] = NULL;
      THCudaCheck(cudaStreamCreateWithFlags(newStreams + stream, flags));
      newScratchSpace[stream] = NULL;
      THCudaCheck(THCudaMalloc(state, &newScratchSpace[stream], scratchSpaceSize));
    }

    /* Swap the enlarged tables in; only the old host arrays are freed, the
       stream handles and scratch pointers they held were copied above */
    THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, dev);
    free(res->streams);
    res->streams = newStreams;
    free(res->devScratchSpacePerStream);
    res->devScratchSpacePerStream = newScratchSpace;
  }

  state->numUserStreams = numStreams;

  THCudaCheck(cudaSetDevice(prevDev));
}
/*
 * Releases the resources held by `state`.
 *
 * Sequence, as performed below:
 *   1. THCRandom_shutdown, then free the rngState, cudaHostAllocator and
 *      deviceProperties host allocations.
 *   2. Record the current device and query the device count from the runtime.
 *   3. Free every row of the p2pAccessEnabled table, then the table itself.
 *   4. For each device: make it current, destroy user streams 1..numUserStreams,
 *      destroy user cuBLAS handles 1..numUserBlasHandles, release the scratch
 *      space of streams 0..numUserStreams through THCudaFree, and free the
 *      device's host-side streams / blasHandles / scratch-pointer arrays.
 *   5. Free the resourcesPerDevice array, invoke the device allocator's
 *      shutdown hook, free both thread-local slots, and restore the device
 *      that was current on entry.
 *
 * CUDA and cuBLAS failures are reported through THCudaCheck / THCublasCheck.
 */
void THCudaShutdown(THCState* state)
{
  int originalDevice = -1;
  int numDevices = 0;

  THCRandom_shutdown(state);

  free(state->rngState);
  free(state->cudaHostAllocator);
  free(state->deviceProperties);

  THCudaCheck(cudaGetDevice(&originalDevice));
  THCudaCheck(cudaGetDeviceCount(&numDevices));

  /* Peer-to-peer access table: one row per device, then the row index */
  {
    int row = 0;
    while (row < numDevices) {
      free(state->p2pAccessEnabled[row]);
      ++row;
    }
  }
  free(state->p2pAccessEnabled);

  /* Per-device teardown */
  for (int device = 0; device < numDevices; ++device) {
    THCudaCheck(cudaSetDevice(device));

    /* User streams only; index 0 is the default stream and is not destroyed */
    for (int s = 1; s <= state->numUserStreams; ++s) {
      THCudaCheck(
        cudaStreamDestroy(THCState_getDeviceStream(state, device, s)));
    }

    /* User cuBLAS handles; index 0 is NULL for consistency with streams */
    for (int h = 1; h <= state->numUserBlasHandles; ++h) {
      THCublasCheck(
        cublasDestroy(THCState_getDeviceBlasHandle(state, device, h)));
    }

    /* Scratch space exists for the default stream too, hence starting at 0 */
    for (int s = 0; s <= state->numUserStreams; ++s) {
      void* scratch = THCState_getDeviceScratchSpace(state, device, s);
      THCudaCheck(THCudaFree(state, scratch));
    }

    /* Host-side bookkeeping arrays for this device */
    free(state->resourcesPerDevice[device].streams);
    free(state->resourcesPerDevice[device].blasHandles);
    free(state->resourcesPerDevice[device].devScratchSpacePerStream);
  }
  free(state->resourcesPerDevice);

  state->cudaDeviceAllocator.shutdown(state->cudaDeviceAllocator.state);

  THCThreadLocal_free(state->currentPerDeviceStream);
  THCThreadLocal_free(state->currentPerDeviceBlasHandle);

  THCudaCheck(cudaSetDevice(originalDevice));
}