/* Select `stream` as the active stream index for whichever CUDA device is
 * current on this thread.
 * If state->currentPerDeviceStream already equals `stream`, this returns
 * without touching the CUDA runtime. Otherwise it queries the current device
 * and forwards to THCState_setStream. A failing cudaGetDevice is reported
 * through THCudaCheck. */
void THCState_setStreamForCurrentDevice(THCState *state, int stream)
{
  int currentDevice = -1;

  /* Fast path: requested stream is already the active one. */
  if (state->currentPerDeviceStream == stream) {
    return;
  }

  THCudaCheck(cudaGetDevice(&currentDevice));
  THCState_setStream(state, currentDevice, stream);
}
/* Lua binding for setting the current device.
 * Argument 1 is a device number. It is decremented by one before being passed
 * to cudaSetDevice, so the Lua side presumably uses 1-based numbering.
 * After switching devices, this re-applies the state's random generator, its
 * current stream index and its current BLAS handle index to the new device.
 * CUDA errors are reported through THCudaCheck.
 * Returns 0, i.e. pushes no results onto the Lua stack. */
static int cutorch_setDevice(lua_State *L)
{
  THCState *state = cutorch_getstate(L);
  int luaDeviceNumber = (int)luaL_checknumber(L, 1);
  int cudaDevice = luaDeviceNumber - 1;

  THCudaCheck(cudaSetDevice(cudaDevice));
  THCRandom_setGenerator(state, cudaDevice);

  /* Streams and BLAS handles are tracked per device, so after the switch the
   * currently selected indices must be bound again for the new device. */
  THCState_setStream(state, cudaDevice,
                     THCState_getCurrentStreamIndex(state));
  THCState_setBlasHandle(state, cudaDevice,
                         THCState_getCurrentBlasHandleIndex(state));

  return 0;
}
/* Make `device` the current CUDA device for this thread and update the
 * THCState bookkeeping that depends on it.
 * If `device` is already current, this returns after a single cudaGetDevice
 * query. Otherwise it switches the device, then updates the random generator
 * and the BLAS handle, and finally re-applies the current stream index for
 * the new device. CUDA errors are reported through THCudaCheck. */
void THCState_setDevice(THCState *state, int device)
{
  int activeDevice;

  THCudaCheck(cudaGetDevice(&activeDevice));
  if (activeDevice == device) {
    return; /* nothing to switch */
  }

  THCudaCheck(cudaSetDevice(device));
  THCRandom_setGenerator(state, device);
  THCudaBlas_setHandle(state, device);

  /* Stream selection is tracked per device, so re-bind it after the switch. */
  THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
}