示例#1
0
// Validates the variables handed to a Function before it runs.
//
// @param name          Function name; used as the prefix of every error message.
// @param inputs        The variables actually received.
// @param args          Exact number of inputs expected; inputs.size() must equal it.
// @param required_args Number of leading inputs that must be defined (not None).
//                      -1 means "all of them", i.e. the same as args.
// @throws std::runtime_error if the input count differs from args, if
//         required_args exceeds args, or if any of the first required_args
//         inputs is undefined.
void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args) {
  if (required_args == -1) {
    required_args = args;
  }
  if (inputs.size() != static_cast<size_t>(args)) {
    std::stringstream ss;
    ss << name << ": expected " << args << " arguments (got " << inputs.size();
    ss << ")";
    throw std::runtime_error(ss.str());
  }
  // At this point inputs.size() == args. Asking for more required inputs than
  // exist would make the loop below index past the end of `inputs` (UB), so
  // reject that caller error explicitly instead.
  if (required_args > args) {
    std::stringstream ss;
    ss << name << ": required_args (" << required_args << ") exceeds args (" << args << ")";
    throw std::runtime_error(ss.str());
  }
  for (int i = 0; i < required_args; ++i) {
    if (!inputs[i].defined()) {
      std::stringstream ss;
      ss << name << ": expected Variable at argument " << i << " (got None)";
      throw std::runtime_error(ss.str());
    }
  }
}
示例#2
0
// Records one application of a Python autograd Function into the active trace
// graph as a PythonOp node, with one output Value per returned variable.
//
// Does nothing when tracer::isTracing(input_vars) is false.
//
// @param op_obj         Python object stored in the PythonOp node and queried for
//                       `is_traceable`; a null op_obj means a legacy function,
//                       which cannot be traced.
// @param bw_obj         The THPFunction for this call; its `is_traced` flag is set.
// @param input_objects  Python tuple of all positional forward arguments
//                       (Variables and non-Variable "scalar" arguments).
// @param output_objects Python tuple of forward outputs; every element is cast to
//                       THPVariable* without a check (assumed to be Variables).
// @param input_vars     The Variable inputs; their traces become the node's inputs.
// @param is_inplace     Stored on the node as the `kinplace` attribute.
// @throws std::runtime_error if op_obj is null (legacy function).
// @throws python_error if reading `op_obj.is_traceable` fails.
static void _trace_create(PyObject* op_obj, THPFunction* bw_obj,
        PyObject *input_objects, PyObject *output_objects,
        const variable_list& input_vars, bool is_inplace) {
  // Bail out early when the tracer reports these inputs are not being traced.
  if (!tracer::isTracing(input_vars))
    return;

  // Legacy functions have no op_obj and are explicitly unsupported by the tracer.
  if (!op_obj) {
    std::ostringstream oss;
    oss << "Attempted to trace " << Py_TYPE(bw_obj)->tp_name;
    oss << ", but tracing of legacy functions is not supported";
    throw std::runtime_error(oss.str());
  }

  auto tracing_state = tracer::getTracingState(input_vars);
  bw_obj->is_traced = true;

  // Collect the C++ Variables behind the Python output objects.
  // NOTE(review): the cast to THPVariable* is unchecked, so every element of
  // output_objects is assumed to be a THPVariable. The loop index is `int`
  // while PyTuple_GET_SIZE returns Py_ssize_t.
  variable_list output_vars;
  for (int i = 0; i < PyTuple_GET_SIZE(output_objects); ++i) {
    THPVariable *var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i);
    output_vars.emplace_back(var->cdata);
  }

  // Save scalar args and the calling convention: arg_types gets one char per
  // positional argument, 't' for a Variable and 's' for anything else. Only the
  // non-Variable arguments are kept in scalar_args, in order.
  auto num_args = PyTuple_GET_SIZE(input_objects);
  pyobj_list scalar_args;
  std::string arg_types;
  arg_types.reserve(num_args);
  scalar_args.reserve(num_args);
  for (int i = 0; i < num_args; i++) {
    PyObject *arg_object = PyTuple_GET_ITEM(input_objects, i);
    if (THPVariable_Check(arg_object)) {
      arg_types.push_back('t');
    } else {
      arg_types.push_back('s');
      // PyTuple_GET_ITEM returns a borrowed reference; take a new one for
      // scalar_args to own (presumably released by pyobj_list's element type —
      // confirm).
      Py_INCREF(arg_object);
      scalar_args.emplace_back(arg_object);
    }
  }

  // Lock the tracing state; the guard lives until the end of this function,
  // covering all graph mutation below.
  // NOTE(review): the lock is also still held during the PyObject_GetAttrString
  // call near the end, i.e. while calling into Python.
  auto state_lock = tracing_state->lock();

  // Note [getValueTrace can allocate nodes]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // When an input variable is not traced, we create a constant instruction
  // to represent it.  This means that you must invoke getValueTrace() BEFORE
  // actually constructing the function that takes these variables as inputs.
  // If we do it the other order, the graph will be in the wrong topological
  // order.

  // See Note [getValueTrace can allocate nodes]
  std::vector<Value*> value_traces;
  value_traces.reserve(input_vars.size());
  for (auto& i : input_vars)
    value_traces.emplace_back(tracer::getValueTrace(tracing_state, i));

  // NB: this function is called only from THPFunction_apply, which is used only
  // when computing forward. All these functions are non-traceable by definition,
  // because they are implemented in terms of tensor operations. Hence, there's no
  // need for any conditionals in here and we can always create the node.

  // Construct the IR Node and its Selects.
  // The extra reference on op_obj is for the THPObjectPtr handed to the node,
  // which presumably adopts (steals) it; the caller's reference is untouched.
  Py_INCREF(op_obj);
  auto& graph = tracing_state->graph;
  auto this_expr = graph->appendNode(graph->createPythonOp(
    THPObjectPtr(op_obj),
    arg_types,
    false, // TODO: remove is_legacy
    std::move(scalar_args)));
  for (auto t : value_traces)
    this_expr->addInput(t);

  // One node output per forward output. Each output takes its type from the
  // corresponding Variable's data and is registered as that Variable's trace.
  int num_outputs = output_vars.size();
  for (int i = 0; i < num_outputs; ++i) {
    auto& output = output_vars[i];
    // NOTE: normally we don't add Select nodes when there's only a single
    // output, but Python nodes can't be optimized away, so we simplify the
    // code here.
    auto sel = this_expr->addOutput();
    sel->inferTypeFrom(output.data());
    tracer::setValueTrace(tracing_state, output, sel);
  }
  this_expr->i_(kinplace, is_inplace);

  // See definition in function.cpp.
  // Only an attribute value that is exactly Py_True counts as traceable; any
  // other value, including other truthy objects, is treated as false.
  THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")};
  if (!passes_py_bool) throw python_error();
  bool passes_state_transparently = passes_py_bool == Py_True;
  // NB: this path is executed only for forward of Python functions, so there's no need to check
  // tracing_state->in_eval_subgraph (it's always false, because they are never part of backward
  // subgraphs AND we don't even materialize the forward function).
  // When the function is not traceable, its backward is marked non-traceable
  // and a context edge is set up between the node and its inputs/outputs.
  if (!passes_state_transparently) {
    tracer::nontraceableBackwardSubgraph(input_vars, output_vars);
    Function::setUpContextEdge(this_expr, input_vars, output_vars);
  }
}