// Attempt to cast an expression to a smaller type while provably not losing
// information. If it can't be done, return an undefined Expr.
Expr lossless_cast(Type t, Expr e) {
    if (t == e.type()) {
        return e;
    } else if (t.can_represent(e.type())) {
        return cast(t, e);
    }

    if (const Cast *c = e.as<Cast>()) {
        if (t.can_represent(c->value.type())) {
            // We can recurse into widening casts.
            return lossless_cast(t, c->value);
        } else {
            return Expr();
        }
    }

    if (const Broadcast *b = e.as<Broadcast>()) {
        Expr v = lossless_cast(t.element_of(), b->value);
        if (v.defined()) {
            return Broadcast::make(v, b->lanes);
        } else {
            return Expr();
        }
    }

    if (const IntImm *i = e.as<IntImm>()) {
        if (t.can_represent(i->value)) {
            return make_const(t, i->value);
        } else {
            return Expr();
        }
    }

    if (const UIntImm *i = e.as<UIntImm>()) {
        if (t.can_represent(i->value)) {
            return make_const(t, i->value);
        } else {
            return Expr();
        }
    }

    if (const FloatImm *f = e.as<FloatImm>()) {
        if (t.can_represent(f->value)) {
            return make_const(t, f->value);
        } else {
            return Expr();
        }
    }

    return Expr();
}
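lossless_cast(t, e) either returns e re-expressed in type t with no loss of information (stripping widening casts, recursing through broadcasts, and re-typing constants that fit), or an undefined Expr when it cannot prove the conversion is exact. A minimal sketch of the behaviour (the expression names here are illustrative, not from the original source):

// Assume 'a' is an Int(8, 16) vector expression.
Expr widened = cast(Int(16, 16), a);
Expr narrowed = lossless_cast(Int(8, 16), widened);  // strips the widening cast, yields 'a'
Expr too_big = lossless_cast(Int(8, 16), make_const(Int(16, 16), 300));
// too_big.defined() is false: 300 does not fit in an int8.

The x86 backend below relies on exactly this to check that the operands of a widened arithmetic pattern really fit in the narrower destination type.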
Example #2
void CodeGen_X86::visit(const Cast *op) {

    if (!op->type.is_vector()) {
        // We only have peephole optimizations for vectors in here.
        CodeGen_Posix::visit(op);
        return;
    }

    vector<Expr> matches;

    struct Pattern {
        Target::Feature feature;
        bool wide_op;
        Type type;
        int min_lanes;
        string intrin;
        Expr pattern;
    };

    static Pattern patterns[] = {
        {Target::AVX2, true, Int(8, 32), 0, "llvm.x86.avx2.padds.b",
         i8_sat(wild_i16x_ + wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 16), 0, "llvm.x86.sse2.padds.b",
         i8_sat(wild_i16x_ + wild_i16x_)},
        {Target::AVX2, true, Int(8, 32), 0, "llvm.x86.avx2.psubs.b",
         i8_sat(wild_i16x_ - wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 16), 0, "llvm.x86.sse2.psubs.b",
         i8_sat(wild_i16x_ - wild_i16x_)},
#if LLVM_VERSION < 80
        // Older LLVM versions support this as an intrinsic
        {Target::AVX2, true, UInt(8, 32), 0, "llvm.x86.avx2.paddus.b",
         u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "llvm.x86.sse2.paddus.b",
         u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::AVX2, true, UInt(8, 32), 0, "llvm.x86.avx2.psubus.b",
         u8(max(wild_i16x_ - wild_i16x_, 0))},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "llvm.x86.sse2.psubus.b",
         u8(max(wild_i16x_ - wild_i16x_, 0))},
#else
        // LLVM 8.0+ requires using helpers from x86.ll
        {Target::AVX2, true, UInt(8, 32), 0, "paddusbx32",
         u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "paddusbx16",
         u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::AVX2, true, UInt(8, 32), 0, "psubusbx32",
         u8(max(wild_i16x_ - wild_i16x_, 0))},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "psubusbx16",
         u8(max(wild_i16x_ - wild_i16x_, 0))},
#endif
        {Target::AVX2, true, Int(16, 16), 0, "llvm.x86.avx2.padds.w",
         i16_sat(wild_i32x_ + wild_i32x_)},
        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.x86.sse2.padds.w",
         i16_sat(wild_i32x_ + wild_i32x_)},
        {Target::AVX2, true, Int(16, 16), 0, "llvm.x86.avx2.psubs.w",
         i16_sat(wild_i32x_ - wild_i32x_)},
        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.x86.sse2.psubs.w",
         i16_sat(wild_i32x_ - wild_i32x_)},
#if LLVM_VERSION < 80
        // Older LLVM versions support this as an intrinsic
        {Target::AVX2, true, UInt(16, 16), 0, "llvm.x86.avx2.paddus.w",
         u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "llvm.x86.sse2.paddus.w",
         u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::AVX2, true, UInt(16, 16), 0, "llvm.x86.avx2.psubus.w",
         u16(max(wild_i32x_ - wild_i32x_, 0))},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "llvm.x86.sse2.psubus.w",
         u16(max(wild_i32x_ - wild_i32x_, 0))},
#else
        // LLVM 8.0+ requires using helpers from x86.ll
        {Target::AVX2, true, UInt(16, 16), 0, "padduswx16",
         u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "padduswx8",
         u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::AVX2, true, UInt(16, 16), 0, "psubuswx16",
         u16(max(wild_i32x_ - wild_i32x_, 0))},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "psubuswx8",
         u16(max(wild_i32x_ - wild_i32x_, 0))},
#endif
        // Only use the AVX2 version if we have more than 8 lanes
        {Target::AVX2, true, Int(16, 16), 9, "llvm.x86.avx2.pmulh.w",
         i16((wild_i32x_ * wild_i32x_) / 65536)},
        {Target::AVX2, true, UInt(16, 16), 9, "llvm.x86.avx2.pmulhu.w",
         u16((wild_u32x_ * wild_u32x_) / 65536)},

        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.x86.sse2.pmulh.w",
         i16((wild_i32x_ * wild_i32x_) / 65536)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "llvm.x86.sse2.pmulhu.w",
         u16((wild_u32x_ * wild_u32x_) / 65536)},
        // LLVM 6.0+ requires using helpers from x86.ll
        {Target::AVX2, true, UInt(8, 32), 0, "pavgbx32",
         u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "pavgbx16",
         u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {Target::AVX2, true, UInt(16, 16), 0, "pavgwx16",
         u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "pavgwx8",
         u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {Target::AVX2, false, Int(16, 16), 0, "packssdwx16",
         i16_sat(wild_i32x_)},
        {Target::FeatureEnd, false, Int(16, 8), 0, "packssdwx8",
         i16_sat(wild_i32x_)},
        {Target::AVX2, false, Int(8, 32), 0, "packsswbx32",
         i8_sat(wild_i16x_)},
        {Target::FeatureEnd, false, Int(8, 16), 0, "packsswbx16",
         i8_sat(wild_i16x_)},
        {Target::AVX2, false, UInt(8, 32), 0, "packuswbx32",
         u8_sat(wild_i16x_)},
        {Target::FeatureEnd, false, UInt(8, 16), 0, "packuswbx16",
         u8_sat(wild_i16x_)},
        {Target::AVX2, false, UInt(16, 16), 0, "packusdwx16",
         u16_sat(wild_i32x_)},
        {Target::SSE41, false, UInt(16, 8), 0, "packusdwx8",
         u16_sat(wild_i32x_)}
    };

    for (size_t i = 0; i < sizeof(patterns)/sizeof(patterns[0]); i++) {
        const Pattern &pattern = patterns[i];

        if (!target.has_feature(pattern.feature)) {
            continue;
        }

        if (op->type.lanes() < pattern.min_lanes) {
            continue;
        }

        if (expr_match(pattern.pattern, op, matches)) {
            bool match = true;
            if (pattern.wide_op) {
                // Try to narrow the matches to the target type.
                for (size_t i = 0; i < matches.size(); i++) {
                    matches[i] = lossless_cast(op->type, matches[i]);
                    if (!matches[i].defined()) match = false;
                }
            }
            if (match) {
                value = call_intrin(op->type, pattern.type.lanes(), pattern.intrin, matches);
                return;
            }
        }
    }

    // Workaround for https://llvm.org/bugs/show_bug.cgi?id=24512
    // LLVM uses a numerically unstable method for vector
    // uint32->float conversion before AVX.
    if (op->value.type().element_of() == UInt(32) &&
        op->type.is_float() &&
        op->type.is_vector() &&
        !target.has_feature(Target::AVX)) {
        Type signed_type = Int(32, op->type.lanes());
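        // For unsigned x, x == 2*(x/2) + (x%2), and x/2 always fits in a
        // signed 32-bit integer, so both halves can take the well-behaved
        // signed conversion path and be recombined below.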

        // Convert the top 31 bits to float using the signed version
        Expr top_bits = cast(signed_type, op->value / 2);
        top_bits = cast(op->type, top_bits);

        // Convert the bottom bit
        Expr bottom_bit = cast(signed_type, op->value % 2);
        bottom_bit = cast(op->type, bottom_bit);

        // Recombine as floats
        codegen(top_bits + top_bits + bottom_bit);
        return;
    }

    CodeGen_Posix::visit(op);
}
Example #3
void CodeGen_X86::visit(const Cast *op) {

    if (!op->type.is_vector()) {
        // We only have peephole optimizations for vectors in here.
        CodeGen_Posix::visit(op);
        return;
    }

    vector<Expr> matches;

    struct Pattern {
        bool needs_sse_41;
        bool wide_op;
        Type type;
        string intrin;
        Expr pattern;
    };

    static Pattern patterns[] = {
        {false, true, Int(8, 16), "llvm.x86.sse2.padds.b",
         _i8(clamp(wild_i16x_ + wild_i16x_, -128, 127))},
        {false, true, Int(8, 16), "llvm.x86.sse2.psubs.b",
         _i8(clamp(wild_i16x_ - wild_i16x_, -128, 127))},
        {false, true, UInt(8, 16), "llvm.x86.sse2.paddus.b",
         _u8(min(wild_u16x_ + wild_u16x_, 255))},
        {false, true, UInt(8, 16), "llvm.x86.sse2.psubus.b",
         _u8(max(wild_i16x_ - wild_i16x_, 0))},
        {false, true, Int(16, 8), "llvm.x86.sse2.padds.w",
         _i16(clamp(wild_i32x_ + wild_i32x_, -32768, 32767))},
        {false, true, Int(16, 8), "llvm.x86.sse2.psubs.w",
         _i16(clamp(wild_i32x_ - wild_i32x_, -32768, 32767))},
        {false, true, UInt(16, 8), "llvm.x86.sse2.paddus.w",
         _u16(min(wild_u32x_ + wild_u32x_, 65535))},
        {false, true, UInt(16, 8), "llvm.x86.sse2.psubus.w",
         _u16(max(wild_i32x_ - wild_i32x_, 0))},
        {false, true, Int(16, 8), "llvm.x86.sse2.pmulh.w",
         _i16((wild_i32x_ * wild_i32x_) / 65536)},
        {false, true, UInt(16, 8), "llvm.x86.sse2.pmulhu.w",
         _u16((wild_u32x_ * wild_u32x_) / 65536)},
        {false, true, UInt(8, 16), "llvm.x86.sse2.pavg.b",
         _u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {false, true, UInt(16, 8), "llvm.x86.sse2.pavg.w",
         _u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {false, false, Int(16, 8), "packssdwx8",
         _i16(clamp(wild_i32x_, -32768, 32767))},
        {false, false, Int(8, 16), "packsswbx16",
         _i8(clamp(wild_i16x_, -128, 127))},
        {false, false, UInt(8, 16), "packuswbx16",
         _u8(clamp(wild_i16x_, 0, 255))},
        {true, false, UInt(16, 8), "packusdwx8",
         _u16(clamp(wild_i32x_, 0, 65535))}
    };

    for (size_t i = 0; i < sizeof(patterns)/sizeof(patterns[0]); i++) {
        const Pattern &pattern = patterns[i];

        if (!target.has_feature(Target::SSE41) && pattern.needs_sse_41) {
            continue;
        }

        if (expr_match(pattern.pattern, op, matches)) {
            bool match = true;
            if (pattern.wide_op) {
                // Try to narrow the matches to the target type.
                for (size_t i = 0; i < matches.size(); i++) {
                    matches[i] = lossless_cast(op->type, matches[i]);
                    if (!matches[i].defined()) match = false;
                }
            }
            if (match) {
                value = call_intrin(op->type, pattern.type.lanes(), pattern.intrin, matches);
                return;
            }
        }
    }


    #if LLVM_VERSION >= 38
    // Workaround for https://llvm.org/bugs/show_bug.cgi?id=24512
    // LLVM uses a numerically unstable method for vector
    // uint32->float conversion before AVX.
    if (op->value.type().element_of() == UInt(32) &&
        op->type.is_float() &&
        op->type.is_vector() &&
        !target.has_feature(Target::AVX)) {
        Type signed_type = Int(32, op->type.lanes());

        // Convert the top 31 bits to float using the signed version
        Expr top_bits = cast(signed_type, op->value / 2);
        top_bits = cast(op->type, top_bits);

        // Convert the bottom bit
        Expr bottom_bit = cast(signed_type, op->value % 2);
        bottom_bit = cast(op->type, bottom_bit);

        // Recombine as floats
        codegen(top_bits + top_bits + bottom_bit);
        return;
    }
    #endif


    CodeGen_Posix::visit(op);
}
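For context, a Halide stage written in the saturating form these patterns look for, e.g. the clamp-based padds.b pattern above, might look like the following. This is a minimal sketch, not taken from the original source; the int8 inputs a and b are assumptions.

// Saturating 8-bit add: widen to 16 bits, add, clamp to the int8 range, narrow.
// Vectorized to 16 lanes, this has the shape that the Int(8, 16) padds.b
// pattern matches, and lossless_cast can narrow both operands back to int8.
Func f("f");
Var x("x");
f(x) = cast<int8_t>(clamp(cast<int16_t>(a(x)) + cast<int16_t>(b(x)), -128, 127));
f.vectorize(x, 16);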