// Builds a TensorShape from the node's "shape" attribute.
// Returns a default-constructed TensorShape when the node has no "shape"
// attribute; otherwise each dimension's size is appended in declaration order.
static TensorShape getShape(const tensorflow::NodeDef& node)
{
    TensorShape resultShape;
    const auto& attrs = node.attr();
    // One lookup (instead of count() + at()), and the shape proto is bound by
    // reference so it is not copied.
    const auto it = attrs.find("shape");
    if (it != attrs.end())
    {
        const auto& shape = it->second.shape();
        for (const auto& dim : shape.dim())
        {
            resultShape.addDimension(dim.size());
        }
    }
    return resultShape;
}
static void PopulateNodeDef( const std::wstring& name, const std::wstring& opName, DataType dataType, const std::vector<Variable>& outputs, tensorflow::NodeDef& dst) { dst.set_name(ToString(name)); dst.set_op(ToString(opName)); PopulateDataTypeAttr(dataType, *dst.mutable_attr()); PopulateOutputShapesAttr(outputs, *dst.mutable_attr()); PopulateShapeAttr(outputs[0].Shape(), *dst.mutable_attr()); }
// Walks the node's inputs in order and copies into `tensor` the `weights`
// entry whose key equals the first input name found in the map.
// Returns true on a match; returns false and leaves `tensor` untouched otherwise.
static bool find_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& weights, const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
{
    for (const std::string& input : node.input())
    {
        const auto match = weights.find(input);
        if (match == weights.end())
            continue;

        tensor = match->second;
        return true;
    }
    return false;
}
// Looks up attribute `key` on `node`; on success copies it into `value` and
// returns true. Returns false and leaves `value` untouched when the attribute
// is absent.
static bool find_attr_value(const tensorflow::NodeDef& node, const char* key, tensorflow::AttrValue& value)
{
    const auto& attrs = node.attr();
    const auto entry = attrs.find(key);
    if (entry == attrs.end())
        return false;

    value = entry->second;
    return true;
}
// Looks up the node's own name in `consts`; on success copies the mapped
// TensorProto into `tensor` and returns true. Returns false and leaves
// `tensor` untouched when the name is not present.
static bool get_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& consts, const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
{
    const auto found = consts.find(node.name());
    if (found == consts.end())
        return false;

    tensor = found->second;
    return true;
}