void initialize_python_bindings() { // Initialize the at::Type* pointers, name, and properties of the PyTensorType // vector. After this call, the vector must not be resized. initialize_aten_types(tensor_types); // Initialize the Python metaclass for the torch.FloatTensor, etc. types. // The metaclass handles __instancecheck__ checks and binds the dtype property // on the type objects. py_initialize_metaclass(metaclass); // Get the tp_dict of the Variable class. We copy function definitions // onto each Tensor type object so that they can be accessed via e.g. // `torch.FloatTensor.add`. auto tensor_dict = get_tensor_dict(); // Initialize each Python type object torch.FloatTensor, torch.DoubleTensor, etc. for (auto& tensor_type : tensor_types) { py_initialize_tensor_type(tensor_type.py_type, tensor_type.name, tensor_dict.get()); } // Add the type objects to their corresponding modules. e.g. torch.FloatTensor // is added to the `torch` module as `FloatTensor`. Also add all the type // objects to the set torch._tensor_classes. py_bind_tensor_types(tensor_types); // Use torch.float32 as the default tensor type set_default_tensor_type(torch::CPU(kFloat)); }
// Python entry point for setting the default tensor type from a tensor type
// object such as torch.FloatTensor.
//
// Throws TypeError if `obj` is not a PyTensorType. Throws the error built by
// unavailable_type() if that tensor type has no ATen type attached.
void py_set_default_tensor_type(PyObject* obj) {
  if (!PyTensorType_Check(obj)) {
    throw TypeError("invalid type object");
  }
  auto* tensor_type = reinterpret_cast<PyTensorType*>(obj);
  // A null aten_type means this tensor type is not usable in this build.
  if (tensor_type->aten_type == nullptr) {
    throw unavailable_type(*tensor_type);
  }
  set_default_tensor_type(*tensor_type->aten_type);
}
// Python entry point for setting the default dtype (e.g. torch.float64).
//
// The new default tensor type is looked up from `obj`'s dtype combined with
// the *current* default's layout and whether its device type is CUDA, so only
// the dtype changes.
//
// Throws TypeError if `obj` is not a THPDtype. Throws the error built by
// unavailable_type() if the resulting tensor type has no ATen type attached.
void py_set_default_dtype(PyObject* obj) {
  if (!THPDtype_Check(obj)) {
    throw TypeError("invalid type object");
  }
  // Fixed: this declaration was mis-encoded as `auto ¤t_default`.
  auto& current_default = get_default_tensor_type();
  PyTensorType* type = &get_tensor_type(
      reinterpret_cast<THPDtype*>(obj),
      torch::getLayout(current_default.backend()),
      torch::getDeviceType(current_default) == DeviceType::CUDA);
  if (!type->aten_type) {
    throw unavailable_type(*type);
  }
  set_default_tensor_type(*type->aten_type);
}