/// Returns `type` with every LowCardinality wrapper replaced by its dictionary type.
/// The function descends into Array nested types and Tuple element types.
///
/// Array and Tuple types are always rebuilt as new type objects, even when they contain
/// no LowCardinality. Tuple element names are kept when the tuple has explicit names.
/// Any other type, and a null pointer, is returned as-is.
DataTypePtr recursiveRemoveLowCardinality(const DataTypePtr & type)
{
    if (!type)
        return type;

    /// typeid_cast matches the exact dynamic type, so the three checks below are disjoint.
    if (const auto * as_low_cardinality = typeid_cast<const DataTypeLowCardinality *>(type.get()))
        return as_low_cardinality->getDictionaryType();

    if (const auto * as_array = typeid_cast<const DataTypeArray *>(type.get()))
    {
        DataTypePtr stripped_nested = recursiveRemoveLowCardinality(as_array->getNestedType());
        return std::make_shared<DataTypeArray>(stripped_nested);
    }

    if (const auto * as_tuple = typeid_cast<const DataTypeTuple *>(type.get()))
    {
        const auto & source_elements = as_tuple->getElements();

        DataTypes stripped_elements;
        stripped_elements.reserve(source_elements.size());
        for (const auto & element_type : source_elements)
            stripped_elements.push_back(recursiveRemoveLowCardinality(element_type));

        if (!as_tuple->haveExplicitNames())
            return std::make_shared<DataTypeTuple>(stripped_elements);

        return std::make_shared<DataTypeTuple>(stripped_elements, as_tuple->getElementNames());
    }

    return type;
}
// Example #2
// 0
/// Rebuilds `column` so that its nullability matches `data_type`.
///
/// - A ColumnNullable whose target type is Nullable keeps its null map, and its nested
///   column is processed recursively.
/// - A ColumnNullable whose target type is not Nullable loses the null map; only the
///   nested column is kept, processed recursively.
/// - Array and Tuple columns are processed element-wise against the matching nested
///   element types.
/// - Any other column is returned unchanged.
///
/// Throws LOGICAL_ERROR when the column structure (Array / Tuple, including tuple arity)
/// does not match `data_type`.
ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const
{
    if (const auto * column_nullable = checkAndGetColumn<ColumnNullable>(column.get()))
    {
        const auto * nullable_type = checkAndGetDataType<DataTypeNullable>(data_type.get());
        const auto & nested = column_nullable->getNestedColumnPtr();
        if (nullable_type)
        {
            /// Target stays Nullable: keep the null map, process only the nested values.
            auto casted_column = castRemoveNullable(nested, nullable_type->getNestedType());
            return ColumnNullable::create(casted_column, column_nullable->getNullMapColumnPtr());
        }
        /// Target is not Nullable: drop the null map and continue with the same target type.
        return castRemoveNullable(nested, data_type);
    }
    else if (const auto * column_array = checkAndGetColumn<ColumnArray>(column.get()))
    {
        const auto * array_type = checkAndGetDataType<DataTypeArray>(data_type.get());
        if (!array_type)
            throw Exception{"Cannot cast array column to column with type "
                            + data_type->getName() + " in function " + getName(), ErrorCodes::LOGICAL_ERROR};

        auto casted_column = castRemoveNullable(column_array->getDataPtr(), array_type->getNestedType());
        return ColumnArray::create(casted_column, column_array->getOffsetsPtr());
    }
    else if (const auto * column_tuple = checkAndGetColumn<ColumnTuple>(column.get()))
    {
        const auto * tuple_type = checkAndGetDataType<DataTypeTuple>(data_type.get());

        if (!tuple_type)
            throw Exception{"Cannot cast tuple column to type "
                            + data_type->getName() + " in function " + getName(), ErrorCodes::LOGICAL_ERROR};

        const size_t columns_number = column_tuple->getColumns().size();
        const auto & types = tuple_type->getElements();

        /// Without this check, an arity mismatch would read past the end of `types`.
        if (types.size() != columns_number)
            throw Exception{"Cannot cast tuple column with " + std::to_string(columns_number)
                            + " elements to type " + data_type->getName() + " in function " + getName(),
                            ErrorCodes::LOGICAL_ERROR};

        Columns columns(columns_number);
        for (size_t i = 0; i < columns_number; ++i)
            columns[i] = castRemoveNullable(column_tuple->getColumnPtr(i), types[i]);

        return ColumnTuple::create(columns);
    }

    return column;
}
// Example #3
// 0
/// Casts each argument column of arrayIntersect (selected from `block` by `arguments`)
/// to a common array type. Returns one column per argument.
///
/// - If the element type of `return_type` (ignoring Nullable) is a number, Date/DateTime,
///   String or FixedString, arguments are cast to `return_type`. There is one exception:
///   when the result element type is not Nullable but an argument's element type is,
///   that argument is cast to Array(Nullable(T)) instead, because Nullable(U) cannot be
///   cast to T.
/// - Otherwise every argument is cast to `return_type_with_nulls`, the common supertype
///   that may keep Nullable parts.
///
/// Arguments that already have the target type are reused without casting.
/// Throws LOGICAL_ERROR if `return_type` is not an Array type.
Columns FunctionArrayIntersect::castColumns(
        Block & block, const ColumnNumbers & arguments, const DataTypePtr & return_type,
        const DataTypePtr & return_type_with_nulls) const
{
    size_t num_args = arguments.size();
    Columns columns(num_args);

    const auto * type_array = checkAndGetDataType<DataTypeArray>(return_type.get());
    /// Fail loudly instead of dereferencing a null pointer below.
    if (!type_array)
        throw Exception{"Return type of function " + getName() + " must be Array, got "
                        + return_type->getName(), ErrorCodes::LOGICAL_ERROR};

    const auto & type_nested = type_array->getNestedType();
    auto type_not_nullable_nested = removeNullable(type_nested);

    const bool is_numeric_or_string = isNumber(type_not_nullable_nested)
                                      || isDateOrDateTime(type_not_nullable_nested)
                                      || isStringOrFixedString(type_not_nullable_nested);

    /// Array(Nullable(T)); only needed (and only built) for the numeric/string path.
    DataTypePtr nullable_return_type;

    if (is_numeric_or_string)
    {
        auto type_nullable_nested = makeNullable(type_nested);
        nullable_return_type = std::make_shared<DataTypeArray>(type_nullable_nested);
    }

    const bool nested_is_nullable = type_nested->isNullable();

    for (size_t i = 0; i < num_args; ++i)
    {
        const ColumnWithTypeAndName & arg = block.getByPosition(arguments[i]);
        auto & column = columns[i];

        if (is_numeric_or_string)
        {
            /// Cast to Array(T) or Array(Nullable(T)).
            if (nested_is_nullable)
            {
                if (arg.type->equals(*return_type))
                    column = arg.column;
                else
                    column = castColumn(arg, return_type, context);
            }
            else
            {
                /// If result has array type Array(T) still cast Array(Nullable(U)) to Array(Nullable(T))
                ///  because cannot cast Nullable(T) to T.
                if (arg.type->equals(*return_type) || arg.type->equals(*nullable_return_type))
                    column = arg.column;
                else if (static_cast<const DataTypeArray &>(*arg.type).getNestedType()->isNullable())
                    column = castColumn(arg, nullable_return_type, context);
                else
                    column = castColumn(arg, return_type, context);
            }
        }
        else
        {
            /// return_type_with_nulls is the most common subtype with possible nullable parts.
            if (arg.type->equals(*return_type_with_nulls))
                column = arg.column;
            else
                column = castColumn(arg, return_type_with_nulls, context);
        }
    }

    return columns;
}
/// Converts `column`, whose type is `from_type`, into a column of type `to_type`.
/// The two types may differ only by LowCardinality wrappers: at the top level, or inside
/// Array nested types and Tuple elements.
///
/// - Equal types: the column is returned as-is. A null column is also returned as-is.
/// - ColumnConst: the data column is converted and re-wrapped with the same size.
/// - LowCardinality(T) -> T: the column is materialized into a full column.
/// - T -> LowCardinality(T): a new LowCardinality column is built from the full column.
/// - Array / Tuple: the conversion is applied recursively to the nested / element columns.
///
/// Throws:
///   ILLEGAL_COLUMN  if the column structure does not match `from_type`.
///   TYPE_MISMATCH   if the types cannot be reconciled (including Tuple arity mismatches).
ColumnPtr recursiveLowCardinalityConversion(const ColumnPtr & column, const DataTypePtr & from_type, const DataTypePtr & to_type)
{
    if (!column)
        return column;

    if (from_type->equals(*to_type))
        return column;

    if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get()))
        return ColumnConst::create(recursiveLowCardinalityConversion(column_const->getDataColumnPtr(), from_type, to_type),
                                   column_const->size());

    if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(from_type.get()))
    {
        if (to_type->equals(*low_cardinality_type->getDictionaryType()))
            return column->convertToFullColumnIfLowCardinality();
    }

    if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(to_type.get()))
    {
        if (from_type->equals(*low_cardinality_type->getDictionaryType()))
        {
            auto col = low_cardinality_type->createColumn();
            static_cast<ColumnLowCardinality &>(*col).insertRangeFromFullColumn(*column, 0, column->size());
            /// Hand the freshly built mutable column back as an immutable ColumnPtr.
            return std::move(col);
        }
    }

    if (const auto * from_array_type = typeid_cast<const DataTypeArray *>(from_type.get()))
    {
        if (const auto * to_array_type = typeid_cast<const DataTypeArray *>(to_type.get()))
        {
            const auto * column_array = typeid_cast<const ColumnArray *>(column.get());
            if (!column_array)
                throw Exception("Unexpected column " + column->getName() + " for type " + from_type->getName(),
                                ErrorCodes::ILLEGAL_COLUMN);

            const auto & nested_from = from_array_type->getNestedType();
            const auto & nested_to = to_array_type->getNestedType();

            return ColumnArray::create(
                    recursiveLowCardinalityConversion(column_array->getDataPtr(), nested_from, nested_to),
                    column_array->getOffsetsPtr());
        }
    }

    if (const auto * from_tuple_type = typeid_cast<const DataTypeTuple *>(from_type.get()))
    {
        if (const auto * to_tuple_type = typeid_cast<const DataTypeTuple *>(to_type.get()))
        {
            const auto * column_tuple = typeid_cast<const ColumnTuple *>(column.get());
            if (!column_tuple)
                throw Exception("Unexpected column " + column->getName() + " for type " + from_type->getName(),
                                ErrorCodes::ILLEGAL_COLUMN);

            Columns columns = column_tuple->getColumns();
            const auto & from_elements = from_tuple_type->getElements();
            const auto & to_elements = to_tuple_type->getElements();

            /// Tuples of different arity cannot be converted element-wise. Without this check
            /// a longer `to_type` would silently yield a tuple column with too few elements.
            if (from_elements.size() != to_elements.size())
                throw Exception("Cannot convert: " + from_type->getName() + " to " + to_type->getName(),
                                ErrorCodes::TYPE_MISMATCH);

            if (columns.size() != from_elements.size())
                throw Exception("Unexpected column " + column->getName() + " for type " + from_type->getName(),
                                ErrorCodes::ILLEGAL_COLUMN);

            for (size_t i = 0; i < columns.size(); ++i)
            {
                auto & element = columns[i];
                element = recursiveLowCardinalityConversion(element, from_elements[i], to_elements[i]);
            }
            return ColumnTuple::create(columns);
        }
    }

    throw Exception("Cannot convert: " + from_type->getName() + " to " + to_type->getName(), ErrorCodes::TYPE_MISMATCH);
}