Пример #1
0
void initialize_python_bindings() {
  // Initialize the at::Type* pointers, name, and properties of the PyTensorType
  // vector. After this call, the vector must not be resized.
  initialize_aten_types(tensor_types);

  // Initialize the Python metaclass for the torch.FloatTensor, etc. types.
  // The metaclass handles __instancecheck__ checks and binds the dtype property
  // on the type objects.
  py_initialize_metaclass(metaclass);

  // Get the tp_dict of the Variable class. We copy function definitions
  // onto each Tensor type object so that they can be accessed via e.g.
  // `torch.FloatTensor.add`.
  auto tensor_dict = get_tensor_dict();

  // Initialize each Python type object torch.FloatTensor, torch.DoubleTensor, etc.
  for (auto& tensor_type : tensor_types) {
    py_initialize_tensor_type(tensor_type.py_type, tensor_type.name, tensor_dict.get());
  }

  // Add the type objects to their corresponding modules. e.g. torch.FloatTensor
  // is added to the `torch` module as `FloatTensor`. Also add all the type
  // objects to the set torch._tensor_classes.
  py_bind_tensor_types(tensor_types);

  // Use torch.float32 as the default tensor type
  set_default_tensor_type(torch::CPU(kFloat));
}
Пример #2
0
// Implements torch.set_default_tensor_type(): `obj` must be one of the tensor
// type objects (torch.FloatTensor, ...).
// Throws TypeError when `obj` is not a PyTensorType.
// Throws the exception built by unavailable_type() when that type has no
// backing at::Type in this build.
void py_set_default_tensor_type(PyObject* obj) {
  if (!PyTensorType_Check(obj)) {
    throw TypeError("invalid type object");
  }

  // The check above guarantees `obj` really is a PyTensorType.
  PyTensorType& requested = *reinterpret_cast<PyTensorType*>(obj);

  if (!requested.aten_type) {
    throw unavailable_type(requested);
  }
  set_default_tensor_type(*requested.aten_type);
}
Пример #3
0
// Implements torch.set_default_dtype(): `obj` must be a torch.dtype.
// The new default tensor type keeps the layout and the CUDA/non-CUDA device
// of the current default; only the scalar type changes.
// Throws TypeError when `obj` is not a dtype.
// Throws the exception built by unavailable_type() when the resulting tensor
// type has no backing at::Type in this build.
void py_set_default_dtype(PyObject* obj) {
  if (!THPDtype_Check(obj)) {
    // The argument is validated as a dtype, so say so; the previous
    // "invalid type object" message was copied from set_default_tensor_type.
    throw TypeError("invalid dtype object");
  }

  // Look up the tensor type with the requested dtype but the same layout and
  // device kind as the current default.
  auto& current_default = get_default_tensor_type();
  auto& type = get_tensor_type(reinterpret_cast<THPDtype*>(obj),
                               torch::getLayout(current_default.backend()),
                               torch::getDeviceType(current_default) == DeviceType::CUDA);

  if (!type.aten_type) {
    throw unavailable_type(type);
  }
  set_default_tensor_type(*type.aten_type);
}