// Validates the arguments passed to a C++ autograd function.
//
// name           function name used as the prefix of error messages.
// inputs         the received arguments.
// args           exact number of arguments expected in `inputs`.
// required_args  how many leading arguments must be defined Variables;
//                -1 means all `args` of them.
//
// Throws std::runtime_error if `inputs.size() != args`, or if any of the
// first `required_args` inputs is undefined (reported as "got None").
// NOTE(review): assumes required_args <= args. A larger value would index
// past the end of `inputs` once the size check has passed.
void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args) {
  const int must_be_defined = (required_args == -1) ? args : required_args;

  if (inputs.size() != static_cast<size_t>(args)) {
    std::stringstream msg;
    msg << name << ": expected " << args << " arguments (got " << inputs.size() << ")";
    throw std::runtime_error(msg.str());
  }

  for (int idx = 0; idx < must_be_defined; ++idx) {
    if (inputs[idx].defined()) {
      continue;
    }
    std::stringstream msg;
    msg << name << ": expected Variable at argument " << idx << " (got None)";
    throw std::runtime_error(msg.str());
  }
}
static void _trace_create(PyObject* op_obj, THPFunction* bw_obj, PyObject *input_objects, PyObject *output_objects, const variable_list& input_vars, bool is_inplace) { if (!tracer::isTracing(input_vars)) return; if (!op_obj) { std::ostringstream oss; oss << "Attempted to trace " << Py_TYPE(bw_obj)->tp_name; oss << ", but tracing of legacy functions is not supported"; throw std::runtime_error(oss.str()); } auto tracing_state = tracer::getTracingState(input_vars); bw_obj->is_traced = true; // Isolate C variable ptrs in a vector variable_list output_vars; for (int i = 0; i < PyTuple_GET_SIZE(output_objects); ++i) { THPVariable *var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i); output_vars.emplace_back(var->cdata); } // Save scalar args and the calling convention auto num_args = PyTuple_GET_SIZE(input_objects); pyobj_list scalar_args; std::string arg_types; arg_types.reserve(num_args); scalar_args.reserve(num_args); for (int i = 0; i < num_args; i++) { PyObject *arg_object = PyTuple_GET_ITEM(input_objects, i); if (THPVariable_Check(arg_object)) { arg_types.push_back('t'); } else { arg_types.push_back('s'); Py_INCREF(arg_object); scalar_args.emplace_back(arg_object); } } auto state_lock = tracing_state->lock(); // Note [getValueTrace can allocate nodes] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // When an input variable is not traced, we create a constant instruction // to represent it. This means that you must invoke getValueTrace() BEFORE // actually constructing the function that takes these variables as inputs. // If we do it the other order, the graph will be in the wrong topological // order. // See Note [getValueTrace can allocate nodes] std::vector<Value*> value_traces; value_traces.reserve(input_vars.size()); for (auto& i : input_vars) value_traces.emplace_back(tracer::getValueTrace(tracing_state, i)); // NB: this function is called only from THPFunction_apply, which is used only // when computing forward. 
All these functions are non-traceable by definition, // because they are implemented in terms of tensor operations. Hence, there's no // need for any conditionals in here and we can always create the node. // Construct the IR Node and its Selects Py_INCREF(op_obj); auto& graph = tracing_state->graph; auto this_expr = graph->appendNode(graph->createPythonOp( THPObjectPtr(op_obj), arg_types, false, // TODO: remove is_legacy std::move(scalar_args))); for (auto t : value_traces) this_expr->addInput(t); int num_outputs = output_vars.size(); for (int i = 0; i < num_outputs; ++i) { auto& output = output_vars[i]; // NOTE: normally we don't add Select nodes when there's only a single // output, but Python nodes can't be optimized away, so we simplify the // code here. auto sel = this_expr->addOutput(); sel->inferTypeFrom(output.data()); tracer::setValueTrace(tracing_state, output, sel); } this_expr->i_(kinplace, is_inplace); // See definition in function.cpp. THPObjectPtr passes_py_bool {PyObject_GetAttrString(op_obj, "is_traceable")}; if (!passes_py_bool) throw python_error(); bool passes_state_transparently = passes_py_bool == Py_True; // NB: this path is executed only for forward of Python functions, so there's no need to check // tracing_state->in_eval_subgraph (it's always false, because they are never part of backward // subgraphs AND we don't even materialize the forward function). if (!passes_state_transparently) { tracer::nontraceableBackwardSubgraph(input_vars, output_vars); Function::setUpContextEdge(this_expr, input_vars, output_vars); } }