Exemplo n.º 1
0
// Makes `type` the process-wide default tensor type and points the Python
// attribute `torch.Storage` at the matching storage object.
//
// `type` must be a dense, floating-point, variable (or undefined) type.
//
// Throws:
//   TypeError     - `type` is not floating-point, is not a variable type, or
//                   is sparse.
//   python_error  - importing the `torch` module or assigning `torch.Storage`
//                   failed (a Python exception is expected to be set).
//
// The update is all-or-nothing: every step that can fail runs before the
// global `default_tensor_type` is modified, so an exception leaves the
// previous default untouched.
//
// NOTE(review): the address of `type` is stored in a global, so `type` must
// outlive every later use of the default - presumably at::Type instances are
// long-lived singletons; confirm against the caller.
void set_default_tensor_type(const at::Type& type) {
  if (!at::isFloatingType(type.scalarType())) {
    throw TypeError("only floating-point types are supported as the default type");
  }
  if (!type.is_variable_or_undefined()) {
    throw TypeError("only variable types are supported");
  }
  if (type.is_sparse()) {
    throw TypeError("only dense types are supported as the default type");
  }

  // Resolve the storage object up front; if the lookup fails (presumably by
  // throwing), nothing has been changed yet.
  THPObjectPtr storage = get_storage_obj(type);

  auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
  if (!torch_module) throw python_error();

  // Publish the storage class to Python before touching the C++ global, so a
  // failure here cannot leave the two sides out of sync.
  if (PyObject_SetAttrString(torch_module.get(), "Storage", storage.get()) != 0) {
    throw python_error();
  }

  // Commit last: a plain pointer assignment cannot fail, so there is nothing
  // to roll back.
  default_tensor_type = const_cast<Type*>(&type);
}