Example #1
static PyObject *unpack_saved_variables(
    THPFunction *self,
    std::function<PyObject*(const Variable&)> unpack_fn)
{
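  // Convert self->saved_variables into a Python tuple, applying unpack_fn to
  // each defined variable; raises ERR_BACKWARD_TWICE if the saved buffers
  // have already been freed by a previous backward pass.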
  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
  auto& saved_variables = self->saved_variables;
  if (saved_variables.empty())
    return PyTuple_New(0);

  int num_saved = saved_variables.size();
  THPObjectPtr saved(PyTuple_New(num_saved));
  if (!saved)
    return NULL;
  auto saved_for = THPFunction_asFunction(self);
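  // Unpack each SavedVariable; an undefined (empty) variable becomes Py_None.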
  for (int i = 0; i < num_saved; i++) {
    auto unpacked_var = saved_variables[i].unpack(saved_for);
    THPObjectPtr value;
    if (!unpacked_var.defined()) {
      Py_INCREF(Py_None);
      value = Py_None;
    } else {
      value = unpack_fn(unpacked_var);
    }
    PyTuple_SET_ITEM(saved.get(), i, value.release());
  }
  return saved.release();
}
Example #2
PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
{
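  // HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS translate C++ exceptions thrown
  // below into Python exceptions.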
  HANDLE_TH_ERRORS
  torch::autograd::profiler::RecordFunction record(((PyTypeObject*)cls)->tp_name);

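  // Instantiate cls's _backward_cls; the resulting THPFunction serves both as
  // the ctx passed to forward() and as the grad_fn recorded for the outputs.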
  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls) return NULL;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, NULL));
  if (!ctx_obj) return NULL;
  THPFunction* ctx = (THPFunction*)ctx_obj.get();

  // Prepare inputs and allocate context (grad fn)
  auto info_pair = unpack_input<false>(inputs);
  UnpackedInput& unpacked_input = info_pair.first;
  InputFlags& input_info = info_pair.second;

  // Initialize backward function (and ctx)
  bool is_volatile = input_info.flags.is_volatile;
  ctx->cdata.set_flags(std::move(input_info.flags));
  ctx->needs_input_grad = input_info.needs_input_grad.release();
  ctx->is_variable_input = std::move(input_info.is_variable_input);

  // Prepend ctx to tensor_input, in preparation for static method call
  auto num_args = PyTuple_GET_SIZE(inputs);
  THPObjectPtr ctx_tensor_input(PyTuple_New(num_args + 1));
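  // PyTuple_SET_ITEM steals references: ctx_obj hands its ownership to the new
  // tuple, and the raw ctx pointer above stays valid through that reference.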
  PyTuple_SET_ITEM(ctx_tensor_input.get(), 0, ctx_obj.release());
  for (int i = 0; i < num_args; ++i) {
    PyObject *arg = PyTuple_GET_ITEM(unpacked_input.tensor_input.get(), i);
    Py_INCREF(arg);
    PyTuple_SET_ITEM(ctx_tensor_input.get(), i + 1, arg);
  }

  // Call forward
  THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
  if (!forward_fn) return NULL;
  THPObjectPtr tensor_outputs(PyObject_CallObject(forward_fn, ctx_tensor_input));
  if (!tensor_outputs) return NULL;

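  // Wrap the raw tensor outputs in Variables and, unless the inputs were
  // volatile, record ctx as their grad_fn.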
  THPObjectPtr outputs {process_outputs(cls, ctx, unpacked_input, inputs,
                                        std::move(tensor_outputs), is_volatile)};

  return outputs.release();
  END_HANDLE_TH_ERRORS
}