Beispiel #1
0
    // Determines which kind of value an AttributeProto carries.
    //
    // The attribute is first checked with OpSignature::IsValidAttribute; if that
    // check fails, an error Status is returned and p_type is left untouched.
    // Otherwise the populated field is looked up in a fixed order (singular
    // fields via has_*(), repeated fields via *_size()) and the matching
    // AttrType is written to p_type.
    //
    // @param p_attr  attribute to inspect.
    // @param p_type  [out] receives the detected type; set to AttrType::NONE
    //                when no known field is populated.
    // @return Status::OK() when a type was detected, otherwise a failed Status
    //         with the message "Invalid AttributeProto.".
    Status TypeUtils::GetType(const AttributeProto& p_attr, AttrType& p_type)
    {
        if (!OpSignature::IsValidAttribute(p_attr))
        {
            return Status(false, "Invalid AttributeProto.");
        }

        if (p_attr.has_f())
        {
            p_type = AttrType::FLOAT;
        }
        else if (p_attr.has_i())
        {
            p_type = AttrType::INT;
        }
        else if (p_attr.has_s())
        {
            p_type = AttrType::STRING;
        }
        else if (p_attr.has_t())
        {
            p_type = AttrType::TENSOR;
        }
        else if (p_attr.has_g())
        {
            p_type = AttrType::GRAPH;
        }
        else if (p_attr.floats_size())
        {
            p_type = AttrType::FLOATS;
        }
        else if (p_attr.ints_size())
        {
            p_type = AttrType::INTS;
        }
        else if (p_attr.strings_size())
        {
            p_type = AttrType::STRINGS;
        }
        else if (p_attr.tensors_size())
        {
            p_type = AttrType::TENSORS;
        }
        else if (p_attr.graphs_size())
        {
            p_type = AttrType::GRAPHS;
        }
        else if (p_attr.has_type())
        {
            p_type = AttrType::TYPE;
        }
        else if (p_attr.types_size())
        {
            p_type = AttrType::TYPES;
        }
        else if (p_attr.has_shape())
        {
            p_type = AttrType::SHAPE;
        }
        // The list-of-shapes case must test the repeated field. Re-testing
        // has_shape() here would make this branch unreachable and report
        // shape-list attributes as invalid.
        else if (p_attr.shapes_size())
        {
            p_type = AttrType::SHAPES;
        }
        else
        {
            // No known value field is populated.
            p_type = AttrType::NONE;
            return Status(false, "Invalid AttributeProto.");
        }

        return Status::OK();
    }
Beispiel #2
0
    // Resolves the attribute type of p_attr into p_type.
    //
    // An attribute rejected by OpSignature::IsValidAttribute yields a failed
    // Status without touching p_type. Otherwise p_type takes the value of
    // p_attr.type(); when that value is UNDEFINED the type is inferred from the
    // first populated value field, probing singular fields before repeated
    // ones. If nothing is populated, a failed Status is returned and p_type
    // stays UNDEFINED.
    Status TypeUtils::GetType(const AttributeProto& p_attr, AttrType& p_type)
    {
        if (!OpSignature::IsValidAttribute(p_attr))
        {
            return Status(ONNX, FAIL, "Invalid AttributeProto.");
        }

        // An explicitly declared type is taken as-is.
        p_type = p_attr.type();
        if (p_type != AttrType::AttributeProto_AttributeType_UNDEFINED)
        {
            return Status::OK();
        }

        // No declared type: pick the type of the first populated field.
        // UNDEFINED doubles as the "nothing found" marker.
        const AttrType inferred =
            p_attr.has_f()        ? AttrType::AttributeProto_AttributeType_FLOAT :
            p_attr.has_i()        ? AttrType::AttributeProto_AttributeType_INT :
            p_attr.has_s()        ? AttrType::AttributeProto_AttributeType_STRING :
            p_attr.has_t()        ? AttrType::AttributeProto_AttributeType_TENSOR :
            p_attr.has_g()        ? AttrType::AttributeProto_AttributeType_GRAPH :
            p_attr.floats_size()  ? AttrType::AttributeProto_AttributeType_FLOATS :
            p_attr.ints_size()    ? AttrType::AttributeProto_AttributeType_INTS :
            p_attr.strings_size() ? AttrType::AttributeProto_AttributeType_STRINGS :
            p_attr.tensors_size() ? AttrType::AttributeProto_AttributeType_TENSORS :
            p_attr.graphs_size()  ? AttrType::AttributeProto_AttributeType_GRAPHS :
                                    AttrType::AttributeProto_AttributeType_UNDEFINED;

        if (inferred == AttrType::AttributeProto_AttributeType_UNDEFINED)
        {
            return Status(ONNX, FAIL, "Invalid AttributeProto.");
        }

        p_type = inferred;
        return Status::OK();
    }
Beispiel #3
0
Datei: op.cpp Projekt: Shrsh/CNTK
// Checks whether an AttributeProto is well-formed.
//
// The attribute must have a non-empty name. If its type() is anything other
// than UNDEFINED it is accepted without further checks; otherwise exactly one
// of the value fields (f, i, s, t, g, floats, ints, strings, tensors, graphs)
// must be populated.
bool TypeUtils::IsValidAttribute(const AttributeProto& attr)
{
    if (attr.name().empty())
    {
        return false;
    }

    // A declared type needs no field counting.
    if (attr.type() != AttributeProto_AttributeType_UNDEFINED)
    {
        return true;
    }

    // Count how many value fields are populated.
    int populated = 0;
    if (attr.has_f()) { ++populated; }
    if (attr.has_i()) { ++populated; }
    if (attr.has_s()) { ++populated; }
    if (attr.has_t()) { ++populated; }
    if (attr.has_g()) { ++populated; }
    if (attr.floats_size() > 0) { ++populated; }
    if (attr.ints_size() > 0) { ++populated; }
    if (attr.strings_size() > 0) { ++populated; }
    if (attr.tensors_size() > 0) { ++populated; }
    if (attr.graphs_size() > 0) { ++populated; }

    return populated == 1;
}