/// Builds a TensorShape from the "shape" attribute of a graph node.
///
/// Every dimension stored in the attribute's shape is appended, in order,
/// through TensorShape::addDimension(). If the node carries no "shape"
/// attribute, the default-constructed TensorShape is returned unchanged.
///
/// @param node  Node definition whose attribute map is inspected.
/// @return      A TensorShape with one added dimension per entry of the
///              "shape" attribute, or a default-constructed one when the
///              attribute is absent.
static TensorShape getShape(const tensorflow::NodeDef& node) {
    TensorShape resultShape;

    const auto& attributes = node.attr();
    // One lookup instead of count() followed by at().
    const auto shapeAttr = attributes.find("shape");
    if (shapeAttr != attributes.end()) {
        // Bind by reference: the dimensions are only read, so copying the
        // whole shape message would be wasted work.
        const auto& shape = shapeAttr->second.shape();
        for (int i = 0; i < shape.dim_size(); ++i) {
            resultShape.addDimension(shape.dim(i).size());
        }
    }

    return resultShape;
}
/// Looks up the attribute named `key` on `node`.
///
/// @param node   Node definition whose attribute map is searched.
/// @param key    Name of the attribute to look for.
/// @param value  Receives a copy of the attribute when it is found; left
///               untouched when it is not.
/// @return       true if the attribute exists, false otherwise.
static bool find_attr_value(const tensorflow::NodeDef& node, const char* key, tensorflow::AttrValue& value) {
    const auto& attributes = node.attr();
    const auto entry = attributes.find(key);

    // Guard clause: nothing to copy when the key is missing.
    if (entry == attributes.end()) {
        return false;
    }

    value = entry->second;
    return true;
}