// Replaces the default tensor type with the one for the dtype `obj`.
//
// The layout (from the current default's backend) and CUDA-ness (whether the
// current default's device type is CUDA) are both carried over. Only the
// dtype changes.
//
// Throws TypeError if `obj` is not a dtype object. Throws the exception built
// by unavailable_type() if the selected tensor type has no ATen type.
void py_set_default_dtype(PyObject* obj) {
  if (!THPDtype_Check(obj)) {
    throw TypeError("invalid type object");
  }
  auto& current_default = get_default_tensor_type();
  // Reuse the current default's layout and device kind so that only the
  // dtype changes.
  auto& type = get_tensor_type(
      reinterpret_cast<THPDtype*>(obj),
      torch::getLayout(current_default.backend()),
      torch::getDeviceType(current_default) == DeviceType::CUDA);
  if (!type.aten_type) {
    throw unavailable_type(type);
  }
  set_default_tensor_type(*type.aten_type);
}
// Returns true if the Python object `obj` is acceptable for this parameter's
// declared type (`type_`). The check is only a predicate: `obj` is neither
// converted nor modified.
//
// Throws std::runtime_error if `type_` is not one of the handled ParameterType
// values.
bool FunctionParameter::check(PyObject* obj) {
  switch (type_) {
    case ParameterType::TENSOR: {
      return THPVariable_Check(obj);
    }
    case ParameterType::SCALAR:
    case ParameterType::DOUBLE: {
      // NOTE: we don't currently accept most NumPy types as Scalars. np.float64
      // is okay because it's a subclass of PyFloat. We may want to change this
      // in the future.
      if (THPUtils_checkDouble(obj)) {
        return true;
      }
      // A zero-dim tensor that does not require grad is also accepted.
      if (THPVariable_Check(obj)) {
        const auto& var = ((THPVariable*)obj)->cdata;
        return !var.requires_grad() && var.dim() == 0;
      }
      return false;
    }
    case ParameterType::INT64: {
      if (THPUtils_checkLong(obj)) {
        return true;
      }
      // A zero-dim tensor of integral scalar type that does not require grad
      // is also accepted.
      if (THPVariable_Check(obj)) {
        const auto& var = ((THPVariable*)obj)->cdata;
        return at::isIntegralType(var.type().scalarType()) &&
               !var.requires_grad() && var.dim() == 0;
      }
      return false;
    }
    case ParameterType::TENSOR_LIST:
      return PyTuple_Check(obj) || PyList_Check(obj);
    case ParameterType::INT_LIST: {
      if (PyTuple_Check(obj) || PyList_Check(obj)) {
        return true;
      }
      // if a size is specified (e.g. IntList[2]) we also allow passing a single int
      return size > 0 && THPUtils_checkLong(obj);
    }
    case ParameterType::GENERATOR:
      return THPGenerator_Check(obj);
    case ParameterType::BOOL:
      return PyBool_Check(obj);
    case ParameterType::STORAGE:
      return isStorage(obj);
    case ParameterType::PYOBJECT:
      return true;
    case ParameterType::SCALARTYPE:
      return THPDtype_Check(obj);
    case ParameterType::LAYOUT:
      return THPLayout_Check(obj);
    case ParameterType::DEVICE:
      // Devices may be given as an int, a string, or a device object.
      return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
             THPDevice_Check(obj);
    case ParameterType::STRING:
      return THPUtils_checkString(obj);
    default:
      throw std::runtime_error("unknown parameter type");
  }
}