Ejemplo n.º 1
0
// Registers the tracer's Python bindings on `module_`: the `TracingState`
// class (shared_ptr-held, with dynamic attributes enabled) and the
// module-level `_tracer_enter`, `_tracer_exit`, `_get_tracing_state`,
// `_get_value_trace`, `_set_value_trace` and `_is_tracing` functions.
void initPythonTracerBindings(PyObject* module_) {
  // Hand pythonRecordSourceLocation to setRecordSourceLocation before any
  // binding is registered (same position as in the original code).
  setRecordSourceLocation(pythonRecordSourceLocation);

  auto m = py::handle(module_).cast<py::module>();

  // NB: no constructor; you have to get it from C++ code
  py::class_<TracingState, std::shared_ptr<TracingState>> tracing_state(
      m, "TracingState", py::dynamic_attr());

  // repr shows the address of the underlying C++ object.
  tracing_state.def("__repr__", [](const TracingState& s) -> std::string {
    std::ostringstream out;
    out << "<TracingState " << static_cast<const void*>(&s) << ">";
    return out.str();
  });

  // str streams the graph, unless the state has expired.
  tracing_state.def("__str__", [](const TracingState& s) -> std::string {
    if (s.is_expired()) {
      return "<expired TracingState>";
    }
    std::ostringstream out;
    out << *s.graph;
    return out.str();
  });

  // NOTE(review): ASSERT_UNEXPIRED receives only the method name, so it
  // presumably refers to the lambda parameter `s` by name; keep that
  // parameter name in the lambdas that use the macro.
  tracing_state.def(
      "push_scope", [](TracingState& s, const std::string& scope_name) {
        ASSERT_UNEXPIRED("push_scope");
        s.push_scope(scope_name);
      });
  tracing_state.def("pop_scope", [](TracingState& s) {
    ASSERT_UNEXPIRED("pop_scope");
    s.pop_scope();
  });

  // Graph accessors; neither of these checks for expiry.
  tracing_state.def(
      "set_graph", [](TracingState& s, std::shared_ptr<Graph> g) {
        s.graph = g;
      });
  tracing_state.def("graph", [](TracingState& s) { return s.graph; });

  // Read-only status properties forwarding to the C++ predicates.
  tracing_state.def_property_readonly(
      "is_expired", [](TracingState& s) { return s.is_expired(); });
  tracing_state.def_property_readonly(
      "is_complete", [](TracingState& s) { return s.is_complete(); });

  // Module-level entry points. Note that _tracer_enter forwards
  // num_backwards + 1, not num_backwards itself.
  m.def(
      "_tracer_enter",
      [](variable_list trace_inputs, size_t num_backwards) {
        return tracer::enter(std::move(trace_inputs), num_backwards + 1);
      });
  m.def("_tracer_exit", [](variable_list var_outputs) {
    tracer::exit(var_outputs);
  });
  m.def("_get_tracing_state", [](const variable_list& vars) {
    return getTracingState(vars);
  });
  m.def(
      "_get_value_trace",
      [](std::shared_ptr<TracingState>& state, const Variable& var) {
        return getValueTrace(state, var);
      });
  m.def(
      "_set_value_trace",
      [](std::shared_ptr<TracingState>& state,
         const Variable& var,
         Value* value) {
        return setValueTrace(state, var, value);
      });
  m.def("_is_tracing", [](const variable_list& vars) {
    return isTracingVar(vars);
  });
}
Ejemplo n.º 2
0
// Registers the tracer's Python bindings on `module_`: the `TracingState`
// class (shared_ptr-held, with dynamic attributes enabled) and the
// module-level `_tracer_enter`, `_tracer_exit`, `_get_tracing_state` and
// `_is_tracing` functions.
void initPythonTracerBindings(PyObject* module_) {
  auto m = py::handle(module_).cast<py::module>();

  // NB: no constructor; you have to get it from C++ code
  py::class_<TracingState, std::shared_ptr<TracingState>> tracing_state(
      m, "TracingState", py::dynamic_attr());

  // repr shows the address of the underlying C++ object.
  tracing_state.def("__repr__", [](const TracingState& s) -> std::string {
    std::ostringstream out;
    out << "<TracingState " << static_cast<const void*>(&s) << ">";
    return out.str();
  });

  // str streams the graph, unless the state has expired.
  tracing_state.def("__str__", [](const TracingState& s) -> std::string {
    if (s.is_expired()) {
      return "<expired TracingState>";
    }
    std::ostringstream out;
    out << *s.graph;
    return out.str();
  });

  // NOTE(review): ASSERT_UNEXPIRED receives only the method name, so it
  // presumably refers to the lambda parameter `s` by name; keep that
  // parameter name in the lambdas that use the macro.
  tracing_state.def(
      "push_scope", [](TracingState& s, const std::string& scope_name) {
        ASSERT_UNEXPIRED("push_scope");
        s.push_scope(scope_name);
      });
  tracing_state.def("pop_scope", [](TracingState& s) {
    ASSERT_UNEXPIRED("pop_scope");
    s.pop_scope();
  });

  // Passes the graph, initializers and opset version to ExportGraph and
  // returns the result to Python wrapped in py::bytes.
  tracing_state.def(
      "export",
      [](TracingState& s,
         const std::vector<at::Tensor>& initializers,
         int64_t onnx_opset_version) {
        ASSERT_UNEXPIRED("export");
        return py::bytes(ExportGraph(s.graph, initializers, onnx_opset_version));
      });

  // Graph accessor; performs no expiry check.
  tracing_state.def("graph", [](TracingState& s) { return s.graph; });

  // Read-only status properties forwarding to the C++ predicates.
  tracing_state.def_property_readonly(
      "is_expired", [](TracingState& s) { return s.is_expired(); });
  tracing_state.def_property_readonly(
      "is_complete", [](TracingState& s) { return s.is_complete(); });

  // Module-level entry points. Note that _tracer_enter forwards
  // num_backwards + 1, not num_backwards itself.
  m.def(
      "_tracer_enter",
      [](std::vector<TraceInput> trace_inputs, std::size_t num_backwards) {
        return enter(std::move(trace_inputs), num_backwards + 1);
      });
  m.def("_tracer_exit", [](variable_list var_outputs) {
    tracer::exit(var_outputs);
  });
  m.def("_get_tracing_state", [](const variable_list& vars) {
    return getTracingState(vars);
  });
  m.def("_is_tracing", [](const variable_list& vars) {
    return isTracingVar(vars);
  });
}