Esempio n. 1
0
File: op.cpp Progetto: Shrsh/CNTK
bool TypeUtils::IsValidAttribute(const AttributeProto& attr)
{
    if (attr.name().empty())
    {
        return false;
    }

    if (attr.type() == AttributeProto_AttributeType_UNDEFINED)
    {
        const int num_fields =
            attr.has_f() +
            attr.has_i() +
            attr.has_s() +
            attr.has_t() +
            attr.has_g() +
            (attr.floats_size() > 0) +
            (attr.ints_size() > 0) +
            (attr.strings_size() > 0) +
            (attr.tensors_size() > 0) +
            (attr.graphs_size() > 0);

        if (num_fields != 1)
        {
            return false;
        }
    }
    return true;
}
Esempio n. 2
0
File: op.cpp Progetto: maxiang/CNTK
    Status TypeUtils::GetType(const AttributeProto& p_attr, AttrType& p_type)
    {
        if (!OpSignature::IsValidAttribute(p_attr))
        {
            return Status(ONNX, FAIL, "Invalid AttributeProto.");
        }

        p_type = p_attr.type();
        if (AttrType::AttributeProto_AttributeType_UNDEFINED == p_type)
        {
            if (p_attr.has_f())
            {
                p_type = AttrType::AttributeProto_AttributeType_FLOAT;
            }
            else if (p_attr.has_i())
            {
                p_type = AttrType::AttributeProto_AttributeType_INT;
            }
            else if (p_attr.has_s())
            {
                p_type = AttrType::AttributeProto_AttributeType_STRING;
            }
            else if (p_attr.has_t())
            {
                p_type = AttrType::AttributeProto_AttributeType_TENSOR;
            }
            else if (p_attr.has_g())
            {
                p_type = AttrType::AttributeProto_AttributeType_GRAPH;
            }
            else if (p_attr.floats_size())
            {
                p_type = AttrType::AttributeProto_AttributeType_FLOATS;
            }
            else if (p_attr.ints_size())
            {
                p_type = AttrType::AttributeProto_AttributeType_INTS;
            }
            else if (p_attr.strings_size())
            {
                p_type = AttrType::AttributeProto_AttributeType_STRINGS;
            }
            else if (p_attr.tensors_size())
            {
                p_type = AttrType::AttributeProto_AttributeType_TENSORS;
            }
            else if (p_attr.graphs_size())
            {
                p_type = AttrType::AttributeProto_AttributeType_GRAPHS;
            }
            else
            {
                return Status(ONNX, FAIL, "Invalid AttributeProto.");
            }
        }
        return Status::OK();
    }