Beispiel #1
0
/// Reports whether `obj` is acceptable for this parameter's declared type.
///
/// The decision is driven solely by `type_` (plus `size` for INT_LIST).
///
/// @param obj Python object to test.
/// @return true when `obj` passes the check associated with `type_`.
/// @throws std::runtime_error when `type_` is not one of the handled values.
bool FunctionParameter::check(PyObject* obj) {
  // Shared by the list-like parameter kinds: only tuples and lists qualify.
  const auto is_tuple_or_list = [obj]() -> bool {
    return PyTuple_Check(obj) || PyList_Check(obj);
  };

  switch (type_) {
    case ParameterType::TENSOR:
      return THPVariable_Check(obj);

    case ParameterType::SCALAR:
    case ParameterType::DOUBLE: {
      // NOTE: most NumPy types are not accepted as Scalars at the moment.
      // np.float64 passes because it is a subclass of PyFloat. This may be
      // relaxed in the future.
      if (THPUtils_checkDouble(obj)) {
        return true;
      }
      if (!THPVariable_Check(obj)) {
        return false;
      }
      // A Variable is accepted only when it is 0-dim and does not require grad.
      auto& variable = ((THPVariable*)obj)->cdata;
      return !variable.requires_grad() && variable.dim() == 0;
    }

    case ParameterType::INT64: {
      if (THPUtils_checkLong(obj)) {
        return true;
      }
      if (!THPVariable_Check(obj)) {
        return false;
      }
      // Same rule as for scalars, with the extra demand of an integral
      // scalar type.
      auto& variable = ((THPVariable*)obj)->cdata;
      return at::isIntegralType(variable.type().scalarType()) &&
          !variable.requires_grad() && variable.dim() == 0;
    }

    case ParameterType::TENSOR_LIST:
      return is_tuple_or_list();

    case ParameterType::INT_LIST:
      // When a size is specified (e.g. IntList[2]) a single int is allowed
      // in place of a tuple/list.
      return is_tuple_or_list() || (size > 0 && THPUtils_checkLong(obj));

    case ParameterType::GENERATOR:
      return THPGenerator_Check(obj);

    case ParameterType::BOOL:
      return PyBool_Check(obj);

    case ParameterType::STORAGE:
      return isStorage(obj);

    case ParameterType::PYOBJECT:
      // Any Python object is accepted.
      return true;

    case ParameterType::SCALARTYPE:
      return THPDtype_Check(obj);

    case ParameterType::LAYOUT:
      return THPLayout_Check(obj);

    case ParameterType::DEVICE:
      // An integer, a string or a device object are all accepted.
      return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);

    case ParameterType::STRING:
      return THPUtils_checkString(obj);

    default:
      throw std::runtime_error("unknown parameter type");
  }
}
Beispiel #2
0
/// Throws a TypeError for the first problematic entry found in `kwargs`.
///
/// Each key is examined in dictionary iteration order:
///   - a key that is not a string is rejected outright;
///   - a key for which find_param() yields a negative index is reported as an
///     unexpected keyword argument;
///   - a key whose index is below `num_pos_args` is reported as having
///     multiple values.
/// This function never returns; if no entry triggers one of the errors above,
/// a generic TypeError is thrown instead.
///
/// @param signature    signature whose name appears in the error messages and
///                     which is searched through find_param().
/// @param kwargs       dictionary of keyword arguments to inspect.
/// @param num_pos_args threshold below which a parameter index is reported as
///                     receiving multiple values.
[[noreturn]]
static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, ssize_t num_pos_args) {
  PyObject* kw_name = nullptr;
  PyObject* kw_value = nullptr;
  // NOTE(review): PyDict_Next documents its position argument as
  // Py_ssize_t*; this relies on ssize_t being the same type -- confirm for
  // every target platform.
  ssize_t cursor = 0;

  while (PyDict_Next(kwargs, &cursor, &kw_name, &kw_value)) {
    if (!THPUtils_checkString(kw_name)) {
      throw TypeError("keywords must be strings");
    }

    const auto index = find_param(signature, kw_name);
    if (index < 0) {
      throw TypeError("%s() got an unexpected keyword argument '%s'",
          signature.name.c_str(), THPUtils_unpackString(kw_name).c_str());
    } else if (index < num_pos_args) {
      throw TypeError("%s() got multiple values for argument '%s'",
          signature.name.c_str(), THPUtils_unpackString(kw_name).c_str());
    }
  }

  // Not expected to be reached: it only triggers when no key matched any of
  // the cases above.
  throw TypeError("invalid keyword arguments");
}
Beispiel #3
0
/// Builds a tensor from a Python sequence.
///
/// Steps, in order:
///   1. reject objects that fail PySequence_Check;
///   2. reject strings (which would otherwise pass the sequence check);
///   3. when built WITH_NUMPY, hand NumPy arrays to tensor_from_numpy();
///   4. otherwise size a tensor via compute_sizes() and fill it with
///      recursive_store().
///
/// @param scalarType scalar type used to create the tensor and passed on to
///                   recursive_store(); not consulted on the NumPy path.
/// @param data       Python object supplying the values.
/// @return the result of autograd::make_variable(..., false).
/// @throws TypeError when `data` is not a sequence or is a string.
static Tensor new_from_sequence(ScalarType scalarType, PyObject* data) {
  const bool is_sequence = PySequence_Check(data);
  if (!is_sequence) {
    throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name);
  }
  const bool is_string = THPUtils_checkString(data);
  if (is_string) {
    throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
  }
#ifdef WITH_NUMPY
  if (PyArray_Check(data)) {
    return autograd::make_variable(tensor_from_numpy(data), false);
  }
#endif

  // Generic path: derive the shape from the sequence, create a tensor of
  // that shape, then copy the elements into its buffer.
  auto shape = compute_sizes(data);
  auto result = autograd::make_variable(CPU(scalarType).tensor(shape), false);

  char* const buffer = (char*)result.data_ptr();
  const auto element_size = result.type().elementSizeInBytes();
  recursive_store(
      buffer, result.sizes(), result.strides(), 0,
      scalarType, element_size, data);
  return result;
}
Beispiel #4
0
/// Rewrites `tp_name` of every type object in `arg` to
/// "<__module__>.<previous tp_name>".
///
/// @param self unused.
/// @param arg  object convertible by PySequence_Fast; each item must be a
///             type object whose `__module__` attribute is a string.
/// @return None on success. NULL when PySequence_Fast or the `__module__`
///         lookup fails. THPUtils_assert is a project macro; presumably it
///         also sets an error and returns NULL on failure -- confirm.
static PyObject * THPModule_initNames(PyObject *self, PyObject *arg)
{
  // `tp_name` is set to point straight into a string stored here, so a
  // stored string must never be relocated afterwards. A single flat
  // std::vector<std::string> breaks that on the second call: growing it can
  // relocate the strings of the first call, and a short-string-optimised
  // string then has a different c_str() address, leaving the earlier
  // `tp_name` pointers dangling.
  //
  // Each call therefore gets its own batch, reserved to its final size
  // before any pointer is handed out. Growing the outer vector only moves
  // the inner vectors, which hand over their element buffers, so strings
  // from earlier calls stay where they are.
  static std::vector<std::vector<std::string>> name_batches;

  THPObjectPtr types(PySequence_Fast(arg, "expected a sequence"));
  if (!types) return NULL;

  // Keep the size in its native type instead of narrowing it to int.
  auto num_classes = PySequence_Fast_GET_SIZE(types.get());

  name_batches.emplace_back();
  std::vector<std::string>& names = name_batches.back();
  // With the full capacity reserved, push_back below never reallocates this
  // batch.
  names.reserve(num_classes);

  for (decltype(num_classes) i = 0; i < num_classes; i++) {
    PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
    THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject");
    PyTypeObject* type = (PyTypeObject*)obj;

    THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
    if (!module_name) return NULL;
    THPUtils_assert(THPUtils_checkString(module_name.get()),
        "expected __module__ to be a string");
    std::string name = THPUtils_unpackString(module_name.get());
    names.push_back(name + "." + type->tp_name);
    // The type borrows the character data; `names` keeps it alive.
    type->tp_name = names.back().c_str();
  }
  Py_RETURN_NONE;
}