// Builds a tuple with one entry per entry of self->saved_variables.
// Each saved variable is unpacked against the grad fn returned by
// THPFunction_asFunction(self). A variable that unpacks to an undefined
// Variable becomes Py_None; otherwise the entry is whatever unpack_fn returns.
//
// Raises (via THPUtils_assert, ERR_BACKWARD_TWICE) if self->has_freed_buffers
// is set. Returns a new reference to the tuple (an empty tuple when nothing was
// saved), or NULL on failure.
static PyObject *unpack_saved_variables(
    THPFunction *self,
    std::function<PyObject*(const Variable&)> unpack_fn)
{
  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
  auto& saved_variables = self->saved_variables;
  if (saved_variables.empty())
    return PyTuple_New(0);

  // Keep the count unsigned (matches size()) and convert explicitly at the
  // CPython boundary instead of silently narrowing to int.
  const size_t num_saved = saved_variables.size();
  THPObjectPtr saved(PyTuple_New(static_cast<Py_ssize_t>(num_saved)));
  if (!saved) return NULL;

  auto saved_for = THPFunction_asFunction(self);
  for (size_t i = 0; i < num_saved; ++i) {
    auto unpacked_var = saved_variables[i].unpack(saved_for);
    THPObjectPtr value;
    if (!unpacked_var.defined()) {
      Py_INCREF(Py_None);
      value = Py_None;
    } else {
      value = unpack_fn(unpacked_var);
      // Never store NULL into the tuple. Propagate the failure instead; the
      // error is presumably set by unpack_fn (TODO confirm for every caller).
      // The partially filled tuple is released by `saved`; unfilled slots
      // are NULL, which tuple deallocation tolerates.
      if (!value) return NULL;
    }
    // PyTuple_SET_ITEM steals the reference handed over by release().
    PyTuple_SET_ITEM(saved.get(), static_cast<Py_ssize_t>(i), value.release());
  }
  return saved.release();
}
// Static `apply` entry point for a Python autograd Function class `cls`.
// It performs these steps:
//   1. Instantiates the context object from `cls._backward_cls`.
//   2. Copies the input flags / needs_input_grad / is_variable_input computed
//      by unpack_input<false>() onto it.
//   3. Calls `cls.forward(ctx, *tensor_input)`.
//   4. Passes the raw result to process_outputs().
//
// Returns a new reference to the processed outputs, or NULL with a Python
// error set.
PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
{
  HANDLE_TH_ERRORS
  // Profiler scope named after the Function subclass (presumably covers this call).
  torch::autograd::profiler::RecordFunction record(((PyTypeObject*)cls)->tp_name);

  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls) return NULL;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, NULL));
  if (!ctx_obj) return NULL;
  // Non-owning view. Ownership of the object later moves into
  // ctx_tensor_input, which stays alive until this function returns.
  THPFunction* ctx = (THPFunction*)ctx_obj.get();

  // Prepare inputs and allocate context (grad fn)
  auto info_pair = unpack_input<false>(inputs);
  UnpackedInput& unpacked_input = info_pair.first;
  InputFlags& input_info = info_pair.second;

  // Initialize backward function (and ctx). is_volatile must be read before
  // the flags are moved out of input_info.
  bool is_volatile = input_info.flags.is_volatile;
  ctx->cdata.set_flags(std::move(input_info.flags));
  ctx->needs_input_grad = input_info.needs_input_grad.release();
  ctx->is_variable_input = std::move(input_info.is_variable_input);

  // Build (ctx, *tensor_input) for the static forward() call.
  Py_ssize_t num_args = PyTuple_GET_SIZE(inputs);
  THPObjectPtr ctx_tensor_input(PyTuple_New(num_args + 1));
  // Check before handing ctx_obj over, so ctx is still released on failure.
  if (!ctx_tensor_input) return NULL;
  PyTuple_SET_ITEM(ctx_tensor_input.get(), 0, ctx_obj.release());
  // NOTE(review): assumes unpacked_input.tensor_input has the same length as
  // `inputs` — confirm against unpack_input.
  for (Py_ssize_t i = 0; i < num_args; ++i) {
    PyObject *arg = PyTuple_GET_ITEM(unpacked_input.tensor_input.get(), i);
    Py_INCREF(arg);  // PyTuple_SET_ITEM steals a reference
    PyTuple_SET_ITEM(ctx_tensor_input.get(), i + 1, arg);
  }

  // Call forward
  THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
  if (!forward_fn) return NULL;
  THPObjectPtr tensor_outputs(PyObject_CallObject(forward_fn, ctx_tensor_input));
  if (!tensor_outputs) return NULL;

  THPObjectPtr outputs {process_outputs(cls, ctx, unpacked_input, inputs,
                                        std::move(tensor_outputs), is_volatile)};
  return outputs.release();
  END_HANDLE_TH_ERRORS
}