Expr lossless_cast(Type t, Expr e) {
    if (t == e.type()) {
        return e;
    } else if (t.can_represent(e.type())) {
        return cast(t, e);
    }

    if (const Cast *c = e.as<Cast>()) {
        if (t.can_represent(c->value.type())) {
            // We can recurse into widening casts.
            return lossless_cast(t, c->value);
        } else {
            return Expr();
        }
    }

    if (const Broadcast *b = e.as<Broadcast>()) {
        Expr v = lossless_cast(t.element_of(), b->value);
        if (v.defined()) {
            return Broadcast::make(v, b->lanes);
        } else {
            return Expr();
        }
    }

    if (const IntImm *i = e.as<IntImm>()) {
        if (t.can_represent(i->value)) {
            return make_const(t, i->value);
        } else {
            return Expr();
        }
    }

    if (const UIntImm *i = e.as<UIntImm>()) {
        if (t.can_represent(i->value)) {
            return make_const(t, i->value);
        } else {
            return Expr();
        }
    }

    if (const FloatImm *f = e.as<FloatImm>()) {
        if (t.can_represent(f->value)) {
            return make_const(t, f->value);
        } else {
            return Expr();
        }
    }

    return Expr();
}
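// Hypothetical usage sketch (not part of the original file), assuming Halide's
// internal IR headers and the lossless_cast declaration above are visible; the
// names used below come from those headers. It shows the two outcomes the
// pattern-matching code further down relies on: a widening cast is peeled off,
// and a value that cannot be represented losslessly yields an undefined Expr.
void lossless_cast_example() {
    // Peeling off a widening cast: i16(a) narrows back to the original i8 vector a.
    Expr a = Variable::make(Int(8, 16), "a");
    Expr widened = cast(Int(16, 16), a);
    Expr narrowed = lossless_cast(Int(8, 16), widened);
    // narrowed.defined() == true, and narrowed is the original `a`.

    // A constant that does not fit in the narrower type gives an undefined Expr,
    // which callers treat as "this operand cannot be narrowed".
    Expr big = make_const(Int(16, 16), 300);
    Expr failed = lossless_cast(Int(8, 16), big);
    // failed.defined() == false: 300 is not representable as int8.
}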
void CodeGen_X86::visit(const Cast *op) {
    if (!op->type.is_vector()) {
        // We only have peephole optimizations for vectors in here.
        CodeGen_Posix::visit(op);
        return;
    }

    vector<Expr> matches;

    struct Pattern {
        Target::Feature feature;
        bool wide_op;
        Type type;
        int min_lanes;
        string intrin;
        Expr pattern;
    };

    static Pattern patterns[] = {
        {Target::AVX2, true, Int(8, 32), 0, "llvm.x86.avx2.padds.b", i8_sat(wild_i16x_ + wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 16), 0, "llvm.x86.sse2.padds.b", i8_sat(wild_i16x_ + wild_i16x_)},
        {Target::AVX2, true, Int(8, 32), 0, "llvm.x86.avx2.psubs.b", i8_sat(wild_i16x_ - wild_i16x_)},
        {Target::FeatureEnd, true, Int(8, 16), 0, "llvm.x86.sse2.psubs.b", i8_sat(wild_i16x_ - wild_i16x_)},
#if LLVM_VERSION < 80
        // Older LLVM versions support this as an intrinsic
        {Target::AVX2, true, UInt(8, 32), 0, "llvm.x86.avx2.paddus.b", u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "llvm.x86.sse2.paddus.b", u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::AVX2, true, UInt(8, 32), 0, "llvm.x86.avx2.psubus.b", u8(max(wild_i16x_ - wild_i16x_, 0))},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "llvm.x86.sse2.psubus.b", u8(max(wild_i16x_ - wild_i16x_, 0))},
#else
        // LLVM 8.0+ require using helpers from x86.ll
        {Target::AVX2, true, UInt(8, 32), 0, "paddusbx32", u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "paddusbx16", u8_sat(wild_u16x_ + wild_u16x_)},
        {Target::AVX2, true, UInt(8, 32), 0, "psubusbx32", u8(max(wild_i16x_ - wild_i16x_, 0))},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "psubusbx16", u8(max(wild_i16x_ - wild_i16x_, 0))},
#endif
        {Target::AVX2, true, Int(16, 16), 0, "llvm.x86.avx2.padds.w", i16_sat(wild_i32x_ + wild_i32x_)},
        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.x86.sse2.padds.w", i16_sat(wild_i32x_ + wild_i32x_)},
        {Target::AVX2, true, Int(16, 16), 0, "llvm.x86.avx2.psubs.w", i16_sat(wild_i32x_ - wild_i32x_)},
        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.x86.sse2.psubs.w", i16_sat(wild_i32x_ - wild_i32x_)},
#if LLVM_VERSION < 80
        // Older LLVM versions support this as an intrinsic
        {Target::AVX2, true, UInt(16, 16), 0, "llvm.x86.avx2.paddus.w", u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "llvm.x86.sse2.paddus.w", u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::AVX2, true, UInt(16, 16), 0, "llvm.x86.avx2.psubus.w", u16(max(wild_i32x_ - wild_i32x_, 0))},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "llvm.x86.sse2.psubus.w", u16(max(wild_i32x_ - wild_i32x_, 0))},
#else
        // LLVM 8.0+ require using helpers from x86.ll
        {Target::AVX2, true, UInt(16, 16), 0, "padduswx16", u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "padduswx8", u16_sat(wild_u32x_ + wild_u32x_)},
        {Target::AVX2, true, UInt(16, 16), 0, "psubuswx16", u16(max(wild_i32x_ - wild_i32x_, 0))},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "psubuswx8", u16(max(wild_i32x_ - wild_i32x_, 0))},
#endif
        // Only use the avx2 version if we have > 8 lanes
        {Target::AVX2, true, Int(16, 16), 9, "llvm.x86.avx2.pmulh.w", i16((wild_i32x_ * wild_i32x_) / 65536)},
        {Target::AVX2, true, UInt(16, 16), 9, "llvm.x86.avx2.pmulhu.w", u16((wild_u32x_ * wild_u32x_) / 65536)},
        {Target::FeatureEnd, true, Int(16, 8), 0, "llvm.x86.sse2.pmulh.w", i16((wild_i32x_ * wild_i32x_) / 65536)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "llvm.x86.sse2.pmulhu.w", u16((wild_u32x_ * wild_u32x_) / 65536)},
        // LLVM 6.0+ require using helpers from x86.ll
        {Target::AVX2, true, UInt(8, 32), 0, "pavgbx32", u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {Target::FeatureEnd, true, UInt(8, 16), 0, "pavgbx16", u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {Target::AVX2, true, UInt(16, 16), 0, "pavgwx16", u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {Target::FeatureEnd, true, UInt(16, 8), 0, "pavgwx8", u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {Target::AVX2, false, Int(16, 16), 0, "packssdwx16", i16_sat(wild_i32x_)},
        {Target::FeatureEnd, false, Int(16, 8), 0, "packssdwx8", i16_sat(wild_i32x_)},
        {Target::AVX2, false, Int(8, 32), 0, "packsswbx32", i8_sat(wild_i16x_)},
        {Target::FeatureEnd, false, Int(8, 16), 0, "packsswbx16", i8_sat(wild_i16x_)},
        {Target::AVX2, false, UInt(8, 32), 0, "packuswbx32", u8_sat(wild_i16x_)},
        {Target::FeatureEnd, false, UInt(8, 16), 0, "packuswbx16", u8_sat(wild_i16x_)},
        {Target::AVX2, false, UInt(16, 16), 0, "packusdwx16", u16_sat(wild_i32x_)},
        {Target::SSE41, false, UInt(16, 8), 0, "packusdwx8", u16_sat(wild_i32x_)}
    };

    for (size_t i = 0; i < sizeof(patterns) / sizeof(patterns[0]); i++) {
        const Pattern &pattern = patterns[i];

        if (!target.has_feature(pattern.feature)) {
            continue;
        }
        if (op->type.lanes() < pattern.min_lanes) {
            continue;
        }

        if (expr_match(pattern.pattern, op, matches)) {
            bool match = true;
            if (pattern.wide_op) {
                // Try to narrow the matches to the target type.
                for (size_t i = 0; i < matches.size(); i++) {
                    matches[i] = lossless_cast(op->type, matches[i]);
                    if (!matches[i].defined()) match = false;
                }
            }
            if (match) {
                value = call_intrin(op->type, pattern.type.lanes(), pattern.intrin, matches);
                return;
            }
        }
    }

    // Workaround for https://llvm.org/bugs/show_bug.cgi?id=24512
    // LLVM uses a numerically unstable method for vector
    // uint32->float conversion before AVX.
    if (op->value.type().element_of() == UInt(32) &&
        op->type.is_float() &&
        op->type.is_vector() &&
        !target.has_feature(Target::AVX)) {
        Type signed_type = Int(32, op->type.lanes());

        // Convert the top 31 bits to float using the signed version
        Expr top_bits = cast(signed_type, op->value / 2);
        top_bits = cast(op->type, top_bits);

        // Convert the bottom bit
        Expr bottom_bit = cast(signed_type, op->value % 2);
        bottom_bit = cast(op->type, bottom_bit);

        // Recombine as floats
        codegen(top_bits + top_bits + bottom_bit);
        return;
    }

    CodeGen_Posix::visit(op);
}
void CodeGen_X86::visit(const Cast *op) {
    if (!op->type.is_vector()) {
        // We only have peephole optimizations for vectors in here.
        CodeGen_Posix::visit(op);
        return;
    }

    vector<Expr> matches;

    struct Pattern {
        bool needs_sse_41;
        bool wide_op;
        Type type;
        string intrin;
        Expr pattern;
    };

    static Pattern patterns[] = {
        {false, true, Int(8, 16), "llvm.x86.sse2.padds.b", _i8(clamp(wild_i16x_ + wild_i16x_, -128, 127))},
        {false, true, Int(8, 16), "llvm.x86.sse2.psubs.b", _i8(clamp(wild_i16x_ - wild_i16x_, -128, 127))},
        {false, true, UInt(8, 16), "llvm.x86.sse2.paddus.b", _u8(min(wild_u16x_ + wild_u16x_, 255))},
        {false, true, UInt(8, 16), "llvm.x86.sse2.psubus.b", _u8(max(wild_i16x_ - wild_i16x_, 0))},
        {false, true, Int(16, 8), "llvm.x86.sse2.padds.w", _i16(clamp(wild_i32x_ + wild_i32x_, -32768, 32767))},
        {false, true, Int(16, 8), "llvm.x86.sse2.psubs.w", _i16(clamp(wild_i32x_ - wild_i32x_, -32768, 32767))},
        {false, true, UInt(16, 8), "llvm.x86.sse2.paddus.w", _u16(min(wild_u32x_ + wild_u32x_, 65535))},
        {false, true, UInt(16, 8), "llvm.x86.sse2.psubus.w", _u16(max(wild_i32x_ - wild_i32x_, 0))},
        {false, true, Int(16, 8), "llvm.x86.sse2.pmulh.w", _i16((wild_i32x_ * wild_i32x_) / 65536)},
        {false, true, UInt(16, 8), "llvm.x86.sse2.pmulhu.w", _u16((wild_u32x_ * wild_u32x_) / 65536)},
        {false, true, UInt(8, 16), "llvm.x86.sse2.pavg.b", _u8(((wild_u16x_ + wild_u16x_) + 1) / 2)},
        {false, true, UInt(16, 8), "llvm.x86.sse2.pavg.w", _u16(((wild_u32x_ + wild_u32x_) + 1) / 2)},
        {false, false, Int(16, 8), "packssdwx8", _i16(clamp(wild_i32x_, -32768, 32767))},
        {false, false, Int(8, 16), "packsswbx16", _i8(clamp(wild_i16x_, -128, 127))},
        {false, false, UInt(8, 16), "packuswbx16", _u8(clamp(wild_i16x_, 0, 255))},
        {true, false, UInt(16, 8), "packusdwx8", _u16(clamp(wild_i32x_, 0, 65535))}
    };

    for (size_t i = 0; i < sizeof(patterns) / sizeof(patterns[0]); i++) {
        const Pattern &pattern = patterns[i];

        if (!target.has_feature(Target::SSE41) && pattern.needs_sse_41) {
            continue;
        }

        if (expr_match(pattern.pattern, op, matches)) {
            bool match = true;
            if (pattern.wide_op) {
                // Try to narrow the matches to the target type.
                for (size_t i = 0; i < matches.size(); i++) {
                    matches[i] = lossless_cast(op->type, matches[i]);
                    if (!matches[i].defined()) match = false;
                }
            }
            if (match) {
                value = call_intrin(op->type, pattern.type.lanes(), pattern.intrin, matches);
                return;
            }
        }
    }

#if LLVM_VERSION >= 38
    // Workaround for https://llvm.org/bugs/show_bug.cgi?id=24512
    // LLVM uses a numerically unstable method for vector
    // uint32->float conversion before AVX.
    if (op->value.type().element_of() == UInt(32) &&
        op->type.is_float() &&
        op->type.is_vector() &&
        !target.has_feature(Target::AVX)) {
        Type signed_type = Int(32, op->type.lanes());

        // Convert the top 31 bits to float using the signed version
        Expr top_bits = cast(signed_type, op->value / 2);
        top_bits = cast(op->type, top_bits);

        // Convert the bottom bit
        Expr bottom_bit = cast(signed_type, op->value % 2);
        bottom_bit = cast(op->type, bottom_bit);

        // Recombine as floats
        codegen(top_bits + top_bits + bottom_bit);
        return;
    }
#endif

    CodeGen_Posix::visit(op);
}
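// Standalone sketch (not part of the original file) of the decomposition the
// uint32->float workaround above performs, in scalar form: split x into its
// top 31 bits and its low bit, convert both through signed int32, and
// recombine as floats, avoiding the unstable pre-AVX vector uint32->float
// path. Self-contained; compile as its own translation unit.
#include <cstdint>
#include <cstdio>

static float uint32_to_float_via_signed(uint32_t x) {
    int32_t top_bits = (int32_t)(x / 2);    // always fits in a signed int32
    int32_t bottom_bit = (int32_t)(x % 2);  // 0 or 1
    float top = (float)top_bits;
    float bottom = (float)bottom_bit;
    return top + top + bottom;              // mirrors codegen(top_bits + top_bits + bottom_bit)
}

int main() {
    uint32_t x = 4000000000u;  // larger than INT32_MAX, so a direct signed conversion would be wrong
    printf("%f vs %f\n", uint32_to_float_via_signed(x), (float)x);
    return 0;
}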