// Factor a float into 2^exponent * reduced, where reduced is between 0.75 and 1.5 void range_reduce_log(Expr input, Expr *reduced, Expr *exponent) { Type type = input.type(); Type int_type = Int(32, type.lanes()); Expr int_version = reinterpret(int_type, input); // single precision = SEEE EEEE EMMM MMMM MMMM MMMM MMMM MMMM // exponent mask = 0111 1111 1000 0000 0000 0000 0000 0000 // 0x7 0xF 0x8 0x0 0x0 0x0 0x0 0x0 // non-exponent = 1000 0000 0111 1111 1111 1111 1111 1111 // = 0x8 0x0 0x7 0xF 0xF 0xF 0xF 0xF Expr non_exponent_mask = make_const(int_type, 0x807fffff); // Extract a version with no exponent (between 1.0 and 2.0) Expr no_exponent = int_version & non_exponent_mask; // If > 1.5, we want to divide by two, to normalize back into the // range (0.75, 1.5). We can detect this by sniffing the high bit // of the mantissa. Expr new_exponent = no_exponent >> 22; Expr new_biased_exponent = 127 - new_exponent; Expr old_biased_exponent = int_version >> 23; *exponent = old_biased_exponent - new_biased_exponent; Expr blended = (int_version & non_exponent_mask) | (new_biased_exponent << 23); *reduced = reinterpret(type, blended); }
// Return a constant zero of type `t`.
// Non-handle types get a plain make_const(t, 0). Handle types are instead
// produced by reinterpreting a UInt(64) zero as `t`, giving an all-zero-bits
// handle value.
Expr make_zero(Type t) {
    if (!t.is_handle()) {
        return make_const(t, 0);
    }
    const Expr zero_bits = make_zero(UInt(64));
    return reinterpret(t, zero_bits);
}
// Fast approximation of e^x for a Float(32) expression (scalar or vector).
//
// Method: with k = floor(x / ln(2)), write x = k*ln(2) + r, so that r is
// (approximately) in [0, ln(2)). e^r is approximated by evaluate_polynomial
// over the 8 coefficients below, and 2^k is built directly by writing
// k + 127 into the IEEE-754 exponent field.
// Results whose biased exponent is >= 255 are replaced by the "inf_f32"
// extern (presumably +infinity). Results whose biased exponent is <= 0 are
// replaced by zero, so denormal results are flushed.
//
// x_full: the argument; its element type must be Float(32)
//         (checked with internal_assert).
// Returns an Expr of the same type as x_full.
Expr halide_exp(Expr x_full) {
    Type type = x_full.type();
    internal_assert(type.element_of() == Float(32));

    // ln(2) split into a high part and a small correction so that the range
    // reduction x_full - k*ln(2) is computed more accurately.
    float ln2_part1 = 0.6931457519f;
    float ln2_part2 = 1.4286067653e-6f;
    float one_over_ln2 = 1.0f/logf(2.0f);

    Expr scaled = x_full * one_over_ln2;
    Expr k_real = floor(scaled);

    // Clamp before converting to int. A float-to-int cast of a value outside
    // the int32 range is undefined. For large |x_full| (beyond ~1.5e9) that
    // would make k, k + fpbias and biased << 23 -- and therefore the
    // overflow/underflow selects below -- meaningless.
    // Clamping to [-128, 128] leaves every well-defined result unchanged:
    // any k <= -128 gives biased <= -1 (-> zero) and any k >= 128 gives
    // biased >= 255 (-> inf), exactly as before.
    Expr k_clamped = max(min(k_real, make_const(type, 128)), make_const(type, -128));
    Expr k = cast(Int(32, type.lanes()), k_clamped);

    // Reduced argument r = x_full - k*ln(2). This deliberately uses the
    // unclamped k_real; out-of-range cases are overridden by the selects.
    Expr x = x_full - k_real * ln2_part1;
    x -= k_real * ln2_part2;

    float coeff[] = {
        0.00031965933071842413f,
        0.00119156835564003744f,
        0.00848988645943932717f,
        0.04160188091348320655f,
        0.16667983794100929562f,
        0.49999899033463041098f,
        1.0f,
        1.0f};
    Expr result = evaluate_polynomial(x, coeff, sizeof(coeff)/sizeof(coeff[0]));

    // Compute 2^k. With the clamp above, biased is in [-1, 255], so the
    // shift below cannot overflow int32.
    int fpbias = 127;
    Expr biased = k + fpbias;

    Expr inf = Call::make(type, "inf_f32", {}, Call::PureExtern);

    // Shift the bits up into the exponent field and reinterpret this
    // thing as float.
    Expr two_to_the_n = reinterpret(type, biased << 23);
    result *= two_to_the_n;

    // Catch overflow and underflow
    result = select(biased < 255, result, inf);
    result = select(biased > 0, result, make_zero(type));

    // This introduces lots of common subexpressions
    result = common_subexpression_elimination(result);

    return result;
}