/// Prints a one-line debug summary of an ONNX node to std::cout.
///
/// The summary contains the node's name, op_type, domain, and its first
/// input and first output name. Fields are separated by ", " so the line
/// stays readable. A node without inputs (or outputs) is reported as
/// "(none)" instead of indexing past the end of the repeated field.
///
/// @param node  Node to describe; it is only read.
void printNodeInfo(const onnx::NodeProto &node) {
    std::cout << "name: " << node.name()
              << ", op_type: " << node.op_type()
              << ", domain: " << node.domain();

    // Guard the index: node.input(0) on an empty repeated field is an
    // out-of-range access.
    std::cout << ", input: ";
    if (node.input_size() > 0) {
        std::cout << node.input(0);
    } else {
        std::cout << "(none)";
    }

    std::cout << ", output: ";
    if (node.output_size() > 0) {
        std::cout << node.output(0);
    } else {
        std::cout << "(none)";
    }

    // Keep the flush of the original: this is debug output.
    std::cout << std::endl;
}
inline auto make_softmax_primitive( std::unordered_map<std::string, const mkldnn::memory> const& /*parameter_memory_table*/, std::unordered_map<std::string, std::tuple<const mkldnn::memory, mkldnn::memory::format>> const& variable_memory_table, std::set<std::string> const& required_output_set, onnx::NodeProto const& node, mkldnn::engine const& engine) { constexpr auto softmax_axis = 1; auto const& input_memory_and_origin_format = find_value(variable_memory_table, node.input(0)); auto const& input_memory = std::get<0>(input_memory_and_origin_format); auto input_origin_format = std::get<1>(input_memory_and_origin_format); auto input_output_dims = extract_dims(input_memory); auto const& output_name = node.output(0); std::vector<mkldnn::memory> temp_variable_memory_list; // for temporary memory's life auto op_desc = mkldnn::softmax_forward::desc( mkldnn::prop_kind::forward_inference, input_memory.get_primitive_desc().desc(), softmax_axis); auto op_pd = mkldnn::softmax_forward::primitive_desc(op_desc, engine); std::vector<mkldnn::primitive> net; std::vector<std::pair< std::string, std::tuple<mkldnn::memory, mkldnn::memory::format>>> variable_memory_list; std::vector<std::pair<std::string, array>> output_name_and_arr_list; manage_output_memory(required_output_set, output_name, dtype_t::float_, input_output_dims, input_origin_format, input_memory.get_primitive_desc(), variable_memory_list, temp_variable_memory_list, output_name_and_arr_list, net, engine, [&input_memory, &op_pd](auto& op_output_memory) { return mkldnn::softmax_forward( op_pd, input_memory, op_output_memory); }); return std::make_tuple(net, variable_memory_list, temp_variable_memory_list, output_name_and_arr_list); }
/// Wraps one parameter array of `node` in an MKL-DNN memory object.
///
/// The parameter is the node input at `param_index`; its array is looked up
/// by name in `parameter_table`. The resulting memory is described as f32
/// with the given `format` and the array's dims, and uses the array's data
/// pointer as its handle (no copy is made here).
///
/// @param node             ONNX node whose input name is used as the key.
/// @param param_index      Index of the input that names the parameter.
/// @param format           MKL-DNN memory format to describe the data with.
/// @param parameter_table  Maps parameter names to arrays.
/// @param engine           MKL-DNN engine the memory is created for.
/// @return Pair of (parameter name, mkldnn::memory over the array's data).
inline auto make_parameter_memory_pair(
  onnx::NodeProto const& node, int param_index,
  mkldnn::memory::format format,
  std::unordered_map<std::string, instant::array> const& parameter_table,
  mkldnn::engine const& engine) {
    auto const& name = node.input(param_index);
    auto const& param_array = find_value(parameter_table, name);

    // Convert the array's dims into the dims type MKL-DNN expects.
    mkldnn::memory::dims param_dims(param_array.dims().begin(),
                                    param_array.dims().end());

    mkldnn::memory::desc param_md({param_dims},
                                  mkldnn::memory::data_type::f32, format);
    mkldnn::memory::primitive_desc param_pd(param_md, engine);

    // The memory constructor takes a non-const handle, hence the const_cast
    // on the array's data pointer.
    auto param_memory =
      mkldnn::memory(param_pd, const_cast<void*>(param_array.data()));

    return std::make_pair(name, param_memory);
}
inline auto make_fc_primitive( std::unordered_map<std::string, const mkldnn::memory> const& parameter_memory_table, std::unordered_map<std::string, std::tuple<const mkldnn::memory, mkldnn::memory::format>> const& variable_memory_table, std::set<std::string> const& required_output_set, onnx::NodeProto const& node, mkldnn::engine const& engine) { auto attribute_table = instant::make_attribute_table(node); auto axis = load_attribute_int(attribute_table, "axis"); assert(axis == 1); auto axis_w = load_attribute_int(attribute_table, "axis_w"); assert(axis_w == 1); auto const& input_memory_and_origin_format = find_value(variable_memory_table, node.input(0)); auto const& input_memory = std::get<0>(input_memory_and_origin_format); auto input_origin_format = std::get<1>(input_memory_and_origin_format); auto const& weight_memory = find_value(parameter_memory_table, node.input(1)); auto const& bias_memory = find_value(parameter_memory_table, node.input(2)); auto input_dims = extract_dims(input_memory); auto weight_dims = extract_dims(weight_memory); auto bias_dims = extract_dims(bias_memory); mkldnn::memory::dims output_dims{input_dims[0], bias_dims[0]}; auto const& output_name = node.output(0); auto fc_input_md = mkldnn::memory::desc({input_dims}, mkldnn::memory::data_type::f32, mkldnn::memory::format::any); auto fc_weight_md = mkldnn::memory::desc({weight_dims}, mkldnn::memory::data_type::f32, mkldnn::memory::format::any); auto fc_output_md = mkldnn::memory::desc({output_dims}, mkldnn::memory::data_type::f32, mkldnn::memory::format::any); mkldnn::inner_product_forward::desc fc_desc( mkldnn::prop_kind::forward_inference, fc_input_md, fc_weight_md, bias_memory.get_primitive_desc().desc(), fc_output_md); auto fc_pd = mkldnn::inner_product_forward::primitive_desc(fc_desc, engine); std::vector<mkldnn::primitive> net; std::vector<mkldnn::memory> temp_variable_memory_list; // for temporary memory's life auto fc_input_memory = input_memory; 
if(mkldnn::memory::primitive_desc(fc_pd.src_primitive_desc()) != input_memory.get_primitive_desc()) { fc_input_memory = mkldnn::memory(fc_pd.src_primitive_desc()); temp_variable_memory_list.push_back(fc_input_memory); net.push_back(mkldnn::reorder(input_memory, fc_input_memory)); } auto fc_weight_memory = weight_memory; if(mkldnn::memory::primitive_desc(fc_pd.weights_primitive_desc()) != weight_memory.get_primitive_desc()) { fc_weight_memory = mkldnn::memory(fc_pd.weights_primitive_desc()); temp_variable_memory_list.push_back(fc_weight_memory); net.push_back(mkldnn::reorder(weight_memory, fc_weight_memory)); } std::vector<std::pair< std::string, std::tuple<mkldnn::memory, mkldnn::memory::format>>> variable_memory_list; std::vector<std::pair<std::string, array>> output_name_and_arr_list; manage_output_memory( required_output_set, output_name, dtype_t::float_, output_dims, input_origin_format, fc_pd.dst_primitive_desc(), variable_memory_list, temp_variable_memory_list, output_name_and_arr_list, net, engine, [&fc_pd, &fc_input_memory, &fc_weight_memory, &bias_memory](auto& op_output_memory) { return mkldnn::inner_product_forward( fc_pd, fc_input_memory, fc_weight_memory, bias_memory, op_output_memory); }); return std::make_tuple(net, variable_memory_list, temp_variable_memory_list, output_name_and_arr_list); }