Example #1
File: main.cpp Project: Daiver/jff
int main()
{
    // Abort at startup if the linked protobuf runtime does not match the
    // headers this program was compiled against.
    GOOGLE_PROTOBUF_VERIFY_VERSION;
    onnx::ModelProto modelProto;
    std::fstream input("alexnet.proto", std::ios::in | std::ios::binary);
    if (!parseBigMessageFromIstream(&modelProto, &input)) {
        std::cerr << "Failed to parse model." << std::endl;
        return -1;
    }

    printModelInfo(modelProto);
    // Take a reference instead of copying the (potentially large) graph.
    const onnx::GraphProto &graph = modelProto.graph();
    printGraphInfo(graph);
    printNodeInfo(graph.node(0));

    return 0;
}
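parseBigMessageFromIstream is not shown in this excerpt. Serialized ONNX models such as AlexNet can exceed protobuf's default 64 MB message size limit, so a helper of roughly this shape is typically used. The sketch below is an assumption about how such a helper might look, not the actual implementation from Daiver/jff; the 1 GB limit is illustrative, and the two-argument SetTotalBytesLimit overload applies to older protobuf releases (newer ones take a single argument).

#include <istream>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>

// Hypothetical sketch: parse a large protobuf message (e.g. an ONNX model)
// from an std::istream while raising protobuf's default size limit.
bool parseBigMessageFromIstream(google::protobuf::Message *message,
                                std::istream *input)
{
    google::protobuf::io::IstreamInputStream rawInput(input);
    google::protobuf::io::CodedInputStream codedInput(&rawInput);
    // Allow messages up to 1 GB; the second argument is the warning threshold
    // (two-argument form used by older protobuf versions).
    codedInput.SetTotalBytesLimit(1 << 30, 1 << 29);
    return message->ParseFromCodedStream(&codedInput) &&
           codedInput.ConsumedEntireMessage();
}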
Example #2
    inline auto make_nets(
      onnx::GraphProto const& graph,
      std::unordered_map<std::string, const mkldnn::memory> const&
        parameter_memory_table,
      std::unordered_map<
        std::string, std::tuple<const mkldnn::memory, mkldnn::memory::format>>&
        input_memory_table,
      std::set<std::string> const& required_output_set,
      std::unordered_map<std::string, primitive_factory>
        primitive_factory_table =
          instant::make_default_primitive_factory_table(),
      instant::context const& context = instant::get_context()) {
        auto variable_memory_table = input_memory_table;
        std::unordered_map<std::string, instant::array> output_table;
        std::vector<mkldnn::primitive> nets;
        std::vector<mkldnn::memory> temp_variable_memory_list;
        for(auto const& node : graph.node()) {
            try {
                // Look up the primitive factory registered for this node's op type.
                auto primitive_factory_pair_iter =
                  primitive_factory_table.find(node.op_type());
                if(primitive_factory_pair_iter ==
                   primitive_factory_table.end()) {
                    throw std::runtime_error("Implementation not found: " +
                                             node.op_type());
                }
                auto temp_tuple =
                  primitive_factory_pair_iter->second.operator()(
                    parameter_memory_table, variable_memory_table,
                    required_output_set, node, context.engine());
                auto& net = std::get<0>(temp_tuple);
                auto& output_name_and_memory_and_origin_format_list =
                  std::get<1>(temp_tuple);
                auto& temp_vars = std::get<2>(temp_tuple);
                auto& output_name_and_arr_list = std::get<3>(temp_tuple);

                // Append this node's primitives to the overall net and make
                // its outputs visible to subsequent nodes.
                nets.insert(nets.end(), std::make_move_iterator(net.begin()),
                            std::make_move_iterator(net.end()));
                variable_memory_table.insert(
                  std::make_move_iterator(
                    output_name_and_memory_and_origin_format_list.begin()),
                  std::make_move_iterator(
                    output_name_and_memory_and_origin_format_list.end()));
                temp_variable_memory_list.insert(
                  temp_variable_memory_list.end(),
                  std::make_move_iterator(temp_vars.begin()),
                  std::make_move_iterator(temp_vars.end()));
                output_table.insert(
                  std::make_move_iterator(output_name_and_arr_list.begin()),
                  std::make_move_iterator(output_name_and_arr_list.end()));
            } catch(mkldnn::error const& e) {
                std::cout << "MKLDNN Error: " << e.message << std::endl;
            } catch(std::exception const& e) {
                std::cout << "Error: " << e.what() << std::endl;
            }
        }
        return std::make_tuple(nets, variable_memory_table,
                               temp_variable_memory_list, output_table);
    }
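Execution of the returned primitives is not part of this excerpt. With the MKL-DNN 0.x API that this code targets, the usual pattern is to submit the collected primitives to an eager stream and wait for completion; the driver below is a hedged sketch, and run_nets is a hypothetical name rather than a function shown in the source.

#include <vector>
#include <mkldnn.hpp>

// Hypothetical sketch: run the primitives produced by make_nets on an eager
// MKL-DNN (v0.x) stream and block until they finish.
inline void run_nets(std::vector<mkldnn::primitive> const& nets) {
    mkldnn::stream(mkldnn::stream::kind::eager).submit(nets).wait();
}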
Example #3
File: main.cpp Project: Daiver/jff
void printGraphInfo(const onnx::GraphProto &graph)
{
    std::cout << "graph name " << graph.name() << std::endl;
    std::cout << "doc str: " << graph.doc_string() << std::endl;
    std::cout << "nNodes: " << graph.node_size() << std::endl;
    std::cout << "nInitializers: " << graph.initializer_size() << std::endl;
    std::cout << "Inputs:" << std::endl;
    for(int i = 0; i < graph.input_size(); ++i){
        std::cout << "\tname: " << graph.input(i).name() << std::endl;
    }

    std::cout << "Output:" << std::endl;
    for(int i = 0; i < graph.output_size(); ++i){
        std::cout << "\tname: " << graph.output(i).name() << std::endl;
    }
}
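printNodeInfo, called in example #1, is also not included in the excerpt. A minimal sketch in the same style as printGraphInfo, using only standard generated onnx::NodeProto accessors (the actual function in Daiver/jff may print different fields):

// Assumes <iostream> and the generated onnx protobuf headers are included,
// as in main.cpp.
void printNodeInfo(const onnx::NodeProto &node)
{
    std::cout << "node name: " << node.name() << std::endl;
    std::cout << "op type: " << node.op_type() << std::endl;
    std::cout << "Inputs:" << std::endl;
    for(int i = 0; i < node.input_size(); ++i){
        std::cout << "\tname: " << node.input(i) << std::endl;
    }
    std::cout << "Outputs:" << std::endl;
    for(int i = 0; i < node.output_size(); ++i){
        std::cout << "\tname: " << node.output(i) << std::endl;
    }
}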
Example #4
    inline auto make_parameter_memory_table(
      onnx::GraphProto const& graph,
      std::unordered_map<std::string, instant::array> const& parameter_table,
      mkldnn::engine const& engine) {
        std::unordered_map<std::string, const mkldnn::memory> memory_table;
        std::vector<array> temp_array_list;
        for(auto const& node : graph.node()) {
            if(node.op_type() == "Conv") {
                // Conv inputs are (X, W) or (X, W, B): always register the
                // weight, and register the bias only when it is present.
                constexpr auto weight_index = 1;
                memory_table.insert(make_parameter_memory_pair(
                  node, weight_index, mkldnn::memory::format::oihw,
                  parameter_table, engine));

                if(node.input_size() != 2) {
                    constexpr auto bias_index = 2;
                    memory_table.insert(make_parameter_memory_pair(
                      node, bias_index, mkldnn::memory::format::x,
                      parameter_table, engine));
                }
            } else if(node.op_type() == "FC") {
                constexpr auto weight_index = 1;
                constexpr auto bias_index = 2;
                memory_table.insert(make_parameter_memory_pair(
                  node, weight_index,
                  mkldnn::memory::format::oi, // MEMO: is it correct? result is
                                              // correct...
                  parameter_table, engine));
                memory_table.insert(make_parameter_memory_pair(
                  node, bias_index, mkldnn::memory::format::x, parameter_table,
                  engine));
            } else if(node.op_type() == "BatchNormalization") {
                constexpr auto scale_index = 1;
                constexpr auto b_index = 2;
                constexpr auto mean_index = 3;
                constexpr auto var_index = 4;

                auto const& scale_name = node.input(scale_index);
                auto const& scale_arr = find_value(parameter_table, scale_name);
                auto const& b_name = node.input(b_index);
                auto const& b_arr = find_value(parameter_table, b_name);
                mkldnn::memory::dims scale_dims(scale_arr.dims().begin(),
                                                scale_arr.dims().end());
                // MKL-DNN batch normalization expects scale (gamma) and shift
                // (beta) packed into a single {2, C} weights tensor.
                std::vector<int> weights_dims{{2}};
                weights_dims.insert(weights_dims.end(), scale_dims.begin(),
                                    scale_dims.end());
                array weights_arr(dtype_t::float_, weights_dims);
                std::copy(fbegin(scale_arr), fend(scale_arr),
                          fbegin(weights_arr));
                std::copy(fbegin(b_arr), fend(b_arr),
                          fbegin(weights_arr) + calc_total_size(scale_dims));
                temp_array_list.push_back(weights_arr);
                auto weights_mem =
                  mkldnn::memory({{{weights_dims},
                                   mkldnn::memory::data_type::f32,
                                   mkldnn::memory::format::nc},
                                  engine},
                                 weights_arr.data());
                memory_table.insert({scale_name, weights_mem});

                /*
                memory_table.insert(make_parameter_memory_pair(
                  node, scale_index, mkldnn::memory::format::x, parameter_table,
                  engine));
                memory_table.insert(make_parameter_memory_pair(
                  node, b_index, mkldnn::memory::format::x, parameter_table,
                  engine));
                */
                memory_table.insert(make_parameter_memory_pair(
                  node, mean_index, mkldnn::memory::format::x, parameter_table,
                  engine));
                memory_table.insert(make_parameter_memory_pair(
                  node, var_index, mkldnn::memory::format::x, parameter_table,
                  engine));
            } else {
                // TODO
                /*
                throw std::runtime_error("Not implemented yet: " +
                                         node.op_type());
                */
            }
        }
        return std::make_tuple(memory_table, temp_array_list);
    }
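calc_total_size, used above to offset the shift values past the scale values in the packed weights array, is not defined in this excerpt; it is presumably just the element count implied by a dims vector. A minimal sketch under that assumption (the real helper's name and signature in the project may differ):

#include <functional>
#include <numeric>
#include <vector>

// Hypothetical sketch: total number of elements described by a dims vector,
// assuming calc_total_size is simply the product of all dimensions.
inline std::size_t calc_total_size(std::vector<int> const& dims) {
    return std::accumulate(dims.begin(), dims.end(), std::size_t{1},
                           std::multiplies<std::size_t>());
}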