Beispiel #1
0
/*
   Usage:
   n = cutorch.getBlasHandle()
   Returns the index of the blasHandle currently selected for all devices
   (as previously set via cutorch.setBlasHandle(n)).
*/
static int cutorch_getBlasHandle(lua_State *L)
{
  THCState *state = cutorch_getstate(L);
  int handleIndex = THCState_getCurrentBlasHandleIndex(state);

  /* Hand the index back to Lua as a single number result. */
  lua_pushnumber(L, handleIndex);
  return 1;
}
Beispiel #2
0
/*
   Usage:
   cutorch.setDevice(n)
   Makes the 1-based device n the current CUDA device, calls
   THCRandom_setGenerator for it, and re-applies the currently selected
   stream index and blasHandle index to it.
   Raises a Lua argument error if n is not in [1, number of devices].
*/
static int cutorch_setDevice(lua_State *L)
{
  THCState *state = cutorch_getstate(L);
  lua_Number requested = luaL_checknumber(L, 1);
  int deviceCount = 0;
  THCudaCheck(cudaGetDeviceCount(&deviceCount));

  /* Validate before converting: casting a NaN or out-of-range double to int
     is undefined behavior, and a bad index is better reported as an argument
     error than as a later CUDA failure. The bounds match the truncating
     conversion below, so every index that previously selected a real device
     (e.g. 2.5 -> device 1) is still accepted. */
  luaL_argcheck(L, requested >= 1 && requested < (lua_Number)deviceCount + 1,
                1, "device index out of range");
  int device = (int)requested - 1;

  THCudaCheck(cudaSetDevice(device));
  THCRandom_setGenerator(state, device);

  /* Streams and blasHandles are per device, so re-apply the currently
     selected stream and blasHandle indices to the newly current device. */
  THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
  THCState_setBlasHandle(state, device, THCState_getCurrentBlasHandleIndex(state));

  return 0;
}
Beispiel #3
0
/*
   Returns the blasHandle for the current CUDA device at the currently
   selected blasHandle index. Raises THError when `state` is null, because
   there is no default blasHandle to fall back on.
*/
cublasHandle_t THCState_getCurrentBlasHandle(THCState *state)
{
  /* Reached when a kernel is about to execute; debugging code or kernels
     that were not instrumented properly may pass a null `state`. */
  if (!state) {
    THError("THCState and blasHandles must be set as there is no default blasHandle");
    /* Only reached if THError returns. */
    return NULL;
  }

  int currentDevice;
  THCudaCheck(cudaGetDevice(&currentDevice));

  return THCState_getDeviceBlasHandle(state, currentDevice,
                                      THCState_getCurrentBlasHandleIndex(state));
}