示例#1
0
// Builds a TensorShape from the "shape" attribute of `node`.
//
// Each dimension listed in the attribute is appended, in order, through
// TensorShape::addDimension(). If the node carries no "shape" attribute the
// default-constructed TensorShape is returned unchanged.
static TensorShape getShape(const tensorflow::NodeDef& node) {
    TensorShape resultShape;
    const auto& attrs = node.attr();
    // Single lookup instead of count() followed by at().
    const auto shapeAttr = attrs.find("shape");
    if(shapeAttr != attrs.end()) {
        // Bind by const reference: the accessor returns a reference to the
        // stored message, so taking it by value would deep-copy the proto.
        const auto& shape = shapeAttr->second.shape();
        for(int i = 0; i < shape.dim_size(); i++) {
            resultShape.addDimension(shape.dim(i).size());
        }
    }
    return resultShape;
}
示例#2
0
        // Fills `dst` with the name, op and attributes describing one node.
        //
        // name     - node name; converted with ToString() before being stored.
        // opName   - operation name; converted with ToString() before being stored.
        // dataType - forwarded to PopulateDataTypeAttr().
        // outputs  - forwarded to PopulateOutputShapesAttr(); the shape of the
        //            first element is forwarded to PopulateShapeAttr().
        // dst      - node definition that receives the values.
        //
        // NOTE(review): outputs[0] is read without an emptiness check, so the
        // caller presumably guarantees at least one output - confirm.
        static void PopulateNodeDef(
            const std::wstring& name,
            const std::wstring& opName,
            DataType dataType,
            const std::vector<Variable>& outputs,
            tensorflow::NodeDef& dst)
        {
            dst.set_name(ToString(name));
            dst.set_op(ToString(opName));

            // All three helpers write into the same attribute map of `dst`.
            auto& attributes = *dst.mutable_attr();
            PopulateDataTypeAttr(dataType, attributes);
            PopulateOutputShapesAttr(outputs, attributes);
            PopulateShapeAttr(outputs[0].Shape(), attributes);
        }
示例#3
0
// Searches `weights` for a tensor keyed by one of the input names of `node`.
//
// Inputs are tried in the order they are listed on the node. The first input
// name present in `weights` wins: its tensor is copied into `tensor` and true
// is returned. When no input name matches, `tensor` is left untouched and
// false is returned.
static bool find_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& weights,
                              const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
{
    for (const std::string& candidate : node.input())
    {
        const auto match = weights.find(candidate);
        if (match == weights.end())
            continue;

        tensor = match->second;
        return true;
    }

    return false;
}
示例#4
0
// Looks up the attribute named `key` on `node`.
//
// On a hit the attribute is copied into `value` and true is returned.
// On a miss `value` is left untouched and false is returned.
static bool find_attr_value(const tensorflow::NodeDef& node, const char* key, tensorflow::AttrValue& value)
{
    const auto& attributes = node.attr();
    const auto found = attributes.find(key);

    if (found == attributes.end())
        return false;

    value = found->second;
    return true;
}
示例#5
0
// Fetches the tensor stored in `consts` under the name of `node`.
//
// On a hit the tensor is copied into `tensor` and true is returned.
// On a miss `tensor` is left untouched and false is returned.
static bool get_tensor_proto(const std::map<std::string, tensorflow::TensorProto>& consts,
                             const tensorflow::NodeDef& node, tensorflow::TensorProto& tensor)
{
    const auto entry = consts.find(node.name());

    if (entry == consts.end())
        return false;

    tensor = entry->second;
    return true;
}