Esempio n. 1
0
// Replaces the default tensor type with the one matching the given torch.dtype.
// The layout and the CUDA/non-CUDA device kind are taken from the current
// default tensor type; only the dtype changes.
// Throws TypeError if `obj` is not a dtype object, and unavailable_type if the
// selected PyTensorType has no aten_type.
void py_set_default_dtype(PyObject* obj) {
  // Guard clause: reject anything that is not a dtype before doing any lookup.
  if (!THPDtype_Check(obj)) {
    throw TypeError("invalid type object");
  }

  auto& previous_default = get_default_tensor_type();
  const bool on_cuda = torch::getDeviceType(previous_default) == DeviceType::CUDA;
  PyTensorType& requested = get_tensor_type(reinterpret_cast<THPDtype*>(obj),
                                            torch::getLayout(previous_default.backend()),
                                            on_cuda);

  if (!requested.aten_type) {
    throw unavailable_type(requested);
  }
  set_default_tensor_type(*requested.aten_type);
}
Esempio n. 2
0
// Returns whether `obj` is acceptable for this parameter, based on type_
// (and, for INT_LIST, on `size`).
// Numeric parameters (SCALAR/DOUBLE/INT64) also accept a tensor variable that
// is 0-dimensional and does not require grad; INT64 additionally requires an
// integral scalar type. Throws std::runtime_error for an unhandled type_.
bool FunctionParameter::check(PyObject* obj) {
  // Shared test for "a 0-dim, non-grad tensor may stand in for a number".
  // The integral-dtype condition is checked first, matching the original
  // short-circuit order of the INT64 case.
  auto is_number_like_variable = [obj](bool need_integral) -> bool {
    if (!THPVariable_Check(obj)) {
      return false;
    }
    auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
    if (need_integral && !at::isIntegralType(var.type().scalarType())) {
      return false;
    }
    return !var.requires_grad() && var.dim() == 0;
  };

  switch (type_) {
    case ParameterType::TENSOR:
      return THPVariable_Check(obj);
    case ParameterType::SCALAR:
    case ParameterType::DOUBLE:
      // NOTE: we don't currently accept most NumPy types as Scalars. np.float64
      // is okay because it's a subclass of PyFloat. We may want to change this
      // in the future.
      return THPUtils_checkDouble(obj) || is_number_like_variable(false);
    case ParameterType::INT64:
      return THPUtils_checkLong(obj) || is_number_like_variable(true);
    case ParameterType::TENSOR_LIST:
      return PyTuple_Check(obj) || PyList_Check(obj);
    case ParameterType::INT_LIST:
      // A tuple/list is always accepted; if a size is specified
      // (e.g. IntList[2]) a single int is accepted as well.
      return PyTuple_Check(obj) || PyList_Check(obj) || (size > 0 && THPUtils_checkLong(obj));
    case ParameterType::GENERATOR:
      return THPGenerator_Check(obj);
    case ParameterType::BOOL:
      return PyBool_Check(obj);
    case ParameterType::STORAGE:
      return isStorage(obj);
    case ParameterType::PYOBJECT:
      return true;
    case ParameterType::SCALARTYPE:
      return THPDtype_Check(obj);
    case ParameterType::LAYOUT:
      return THPLayout_Check(obj);
    case ParameterType::DEVICE:
      return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);
    case ParameterType::STRING:
      return THPUtils_checkString(obj);
    default:
      throw std::runtime_error("unknown parameter type");
  }
}