// Wrap `source` so that each bounded coordinate is clamped into
// [min, min + extent - 1]: reads outside the bounds return the nearest edge
// value.
//
// source: the Func to wrap.
// bounds: one (min, extent) pair per leading dimension of `source`. A pair
//   with both Exprs undefined leaves that dimension unbounded. Dimensions past
//   bounds.size() are also left unbounded.
// Returns: a new Func named "repeat_edge" with the same args as `source`.
// Errors: user error if more bounds than dimensions are given, or if exactly
//   one of min/extent is defined for a dimension.
Func repeat_edge(const Func &source, const std::vector<std::pair<Expr, Expr>> &bounds) {
    std::vector<Var> args(source.args());
    user_assert(args.size() >= bounds.size())
        << "repeat_edge called with more bounds (" << bounds.size()
        << ") than dimensions (" << args.size()
        << ") Func " << source.name() << " has.\n";

    std::vector<Expr> actuals;
    actuals.reserve(args.size());
    for (size_t i = 0; i < bounds.size(); i++) {
        Var arg_var = args[i];
        Expr min = bounds[i].first;
        Expr extent = bounds[i].second;

        if (min.defined() && extent.defined()) {
            // The coordinate is wrapped in likely(), a Halide hint that the
            // unclamped (in-range) value is the common case.
            actuals.push_back(clamp(likely(arg_var), min, min + extent - 1));
        } else if (!min.defined() && !extent.defined()) {
            actuals.push_back(arg_var);
        } else {
            user_error << "Partially undefined bounds for dimension " << arg_var
                       << " of Func " << source.name() << "\n";
        }
    }

    // If there were fewer bounds than dimensions, regard the ones at the end as unbounded.
    actuals.insert(actuals.end(), args.begin() + actuals.size(), args.end());

    Func bounded("repeat_edge");
    bounded(args) = source(actuals);
    return bounded;
}
// Wrap `source` so that it keeps its value inside the given bounds and takes
// `value` everywhere outside them.
//
// source: the Func to wrap.
// value: the exterior value. A multi-element Tuple is applied element-wise to
//   a Tuple-valued `source`.
// bounds: one (min, extent) pair per leading dimension of `source`. A pair
//   with both Exprs undefined leaves that dimension unbounded. Dimensions past
//   bounds.size() are also left unbounded.
// Returns: a new Func named "constant_exterior" with the same args as `source`.
// Errors: user error if more bounds than dimensions are given, or if exactly
//   one of min/extent is defined for a dimension.
Func constant_exterior(const Func &source, Tuple value, const std::vector<std::pair<Expr, Expr>> &bounds) {
    std::vector<Var> args(source.args());
    user_assert(args.size() >= bounds.size())
        << "constant_exterior called with more bounds (" << bounds.size()
        << ") than dimensions (" << args.size()
        << ") Func " << source.name() << " has.\n";

    // True when any bounded coordinate lies outside [min, min + extent).
    Expr out_of_bounds = cast<bool>(false);
    for (size_t i = 0; i < bounds.size(); i++) {
        Var arg_var = args[i];
        Expr min = bounds[i].first;
        Expr extent = bounds[i].second;

        if (min.defined() && extent.defined()) {
            out_of_bounds = (out_of_bounds || arg_var < min || arg_var >= min + extent);
        } else if (min.defined() || extent.defined()) {
            user_error << "Partially undefined bounds for dimension " << arg_var
                       << " of Func " << source.name() << "\n";
        }
    }

    Func bounded("constant_exterior");

    // Read the interior through one clamped wrapper. The clamp keeps the
    // (unselected) source access in range even where out_of_bounds holds.
    // Building it once avoids creating a separate wrapper Func per tuple element.
    Func edge = repeat_edge(source, bounds);

    const std::vector<Expr> &values = value.as_vector();
    if (values.size() > 1) {
        std::vector<Expr> def;
        def.reserve(values.size());
        for (size_t i = 0; i < values.size(); i++) {
            def.push_back(select(out_of_bounds, values[i], edge(args)[i]));
        }
        bounded(args) = Tuple(def);
    } else {
        bounded(args) = select(out_of_bounds, values[0], edge(args));
    }
    return bounded;
}
// Wrap `source` so that coordinates outside [min, min + extent - 1] are
// reflected back inside without repeating the edge sample. For example,
// min - 1 maps to min + 1. The reflection has period 2 * (extent - 1).
//
// source: the Func to wrap.
// bounds: one (min, extent) pair per leading dimension of `source`. A pair
//   with both Exprs undefined leaves that dimension unbounded. Dimensions past
//   bounds.size() are also left unbounded.
// Returns: a new Func named "mirror_interior" with the same args as `source`.
// Errors: user error if more bounds than dimensions are given, or if exactly
//   one of min/extent is defined for a dimension.
Func mirror_interior(const Func &source, const std::vector<std::pair<Expr, Expr>> &bounds) {
    std::vector<Var> args(source.args());
    user_assert(args.size() >= bounds.size())
        << "mirror_interior called with more bounds (" << bounds.size()
        << ") than dimensions (" << args.size()
        << ") Func " << source.name() << " has.\n";

    std::vector<Expr> actuals;
    actuals.reserve(args.size());
    for (size_t i = 0; i < bounds.size(); i++) {
        Var arg_var = args[i];
        Expr min = bounds[i].first;
        Expr extent = bounds[i].second;

        if (min.defined() && extent.defined()) {
            // The range comments below assume extent >= 2 (so 2 * limit > 0).
            // They also assume % yields a non-negative result for a positive
            // divisor, which is Halide's documented integer modulo behavior.
            Expr limit = extent - 1;
            Expr coord = arg_var - min;   // Enforce zero origin.
            coord = coord % (2 * limit);  // Range is [0, 2 * limit - 1]
            coord = coord - limit;        // Range is [-limit, limit - 1]
            coord = abs(coord);           // Range is [0, limit]
            coord = limit - coord;        // Range is [0, limit]
            coord = coord + min;          // Restore correct min

            // The boundary condition probably doesn't apply: in range, use a
            // clamp on a likely() coordinate rather than the mirrored one.
            coord = select(arg_var < min || arg_var >= min + extent, coord,
                           clamp(likely(arg_var), min, min + extent - 1));

            actuals.push_back(coord);
        } else if (!min.defined() && !extent.defined()) {
            actuals.push_back(arg_var);
        } else {
            user_error << "Partially undefined bounds for dimension " << arg_var
                       << " of Func " << source.name() << "\n";
        }
    }

    // If there were fewer bounds than dimensions, regard the ones at the end as unbounded.
    actuals.insert(actuals.end(), args.begin() + actuals.size(), args.end());

    Func bounded("mirror_interior");
    bounded(args) = source(actuals);
    return bounded;
}
// Compute the 2D DFT of the real-valued Func `r` and return it as a
// ComplexFunc. Dimension 0 is r.args()[0] (n0) and dimension 1 is
// r.args()[1] (n1).
//
// Any args of `r` beyond the first two are carried through unchanged as
// outer dimensions of the result.
//
// R0, R1: factorizations of the transform sizes; N0 = product(R0) and
//   N1 = product(R1). They are passed to fft_dim1 (presumably as per-stage
//   radices — TODO confirm).
// target: used to choose natural vector widths for zipping and scheduling.
// desc: reads desc.name (Func name prefix, default "r2c_"), desc.gain
//   (overall scale), desc.vector_width (zip group size; <= 0 means use the
//   target's natural width), desc.parallel and desc.schedule_input.
//
// Returns: a ComplexFunc bounded to n0 in [0, N0) and n1 in
//   [0, (N1 + 1) / 2 + 1); its values outside those bounds are undefined.
//
// Method: pairs of real columns are packed into one complex column and
// transformed together along dimension 1, then unzipped. The DC (n1 = 0) and
// Nyquist (n1 = N1 / 2) rows are packed into one row, transposed, and
// transformed along the original dimension 0. Those two rows are then unzipped
// by update definitions 0-5.
//
// NOTE(review): N0 / 2 is used as the zipped extent of dimension 0, so N0
//   presumably must be even — confirm with callers.
// NOTE(review): A(...) is defined outside this function. It appears to build
//   an argument list from the given leading coordinates followed by `args`.
//   `j` is used as the imaginary unit, and `group` is a Var, both declared
//   outside this function — confirm.
ComplexFunc fft2d_r2c(Func r, const vector<int> &R0, const vector<int> &R1, const Target& target, const Fft2dDesc& desc) {
    string prefix = desc.name.empty() ? "r2c_" : desc.name + "_";

    // Peel off the two transformed dimensions; whatever remains in `args`
    // is passed through as outer dimensions.
    vector<Var> args(r.args());
    Var n0(args[0]), n1(args[1]);
    args.erase(args.begin());
    args.erase(args.begin());

    // Get the innermost variable outside the FFT.
    Var outer = Var::outermost();
    if (!args.empty()) {
        outer = args.front();
    }

    int N0 = product(R0);
    int N1 = product(R1);

    // Cache of twiddle factors for this FFT.
    TwiddleFactorSet twiddle_cache;

    // The gain requested of the FFT.
    Expr gain = desc.gain;

    // Combine pairs of real columns x, y into complex columns z = x + j y. This
    // allows us to compute two real DFTs using one complex FFT. See the large
    // comment above this function for more background.
    //
    // An implementation detail is that we zip the columns in groups from the
    // input data to enable the loads to be dense vectors. x is taken from the
    // even indexed groups of columns, y is taken from the odd indexed groups of
    // columns.
    //
    // Changing the group size can (insignificantly) numerically change the result
    // due to regrouping floating point operations. To avoid this, if the FFT
    // description specified a vector width, use it as the group size.
    ComplexFunc zipped(prefix + "zipped");
    int zip_width = desc.vector_width;
    if (zip_width <= 0) {
        zip_width = target.natural_vector_size(r.output_types()[0]);
    }
    // Ensure the zip width divides the zipped extent.
    zip_width = gcd(zip_width, N0 / 2);
    // Zipped column n0 reads group (n0 / zip_width) as an even input group
    // (x, real part); the adjacent odd group zip_width columns later is y.
    Expr zip_n0 = (n0 / zip_width) * zip_width * 2 + (n0 % zip_width);
    zipped(A({n0, n1}, args)) = ComplexExpr(r(A({zip_n0, n1}, args)), r(A({zip_n0 + zip_width, n1}, args)));

    // DFT down the columns first.
    ComplexFunc dft1 = fft_dim1(zipped, R1, -1,  // sign
                                std::min(zip_width, N0 / 2),  // extent of dim 0
                                1.0f, false,  // We parallelize unzipped below instead.
                                prefix, target, &twiddle_cache);

    // Unzip the two groups of real DFTs we zipped together above. For more
    // information about the unzipping operation, see the large comment above this
    // function.
    ComplexFunc unzipped(prefix + "unzipped");
    {
        // Inverse of zip_n0: output column n0 maps back to its zipped column.
        Expr unzip_n0 = (n0 / (zip_width * 2)) * zip_width + (n0 % zip_width);
        ComplexExpr Z = dft1(A({unzip_n0, n1}, args));
        // Conjugate of Z at the mirrored frequency (N1 - n1) mod N1.
        ComplexExpr conjsymZ = conj(dft1(A({unzip_n0, (N1 - n1) % N1}, args)));

        ComplexExpr X = Z + conjsymZ;
        ComplexExpr Y = -j * (Z - conjsymZ);
        // Rather than divide the above expressions by 2 here, adjust the gain
        // instead.
        gain /= 2;

        // Even groups of output columns take X, odd groups take Y.
        unzipped(A({n0, n1}, args)) = select(n0 % (zip_width * 2) < zip_width, X, Y);
    }

    // Zip the DC and Nyquist DFT bin rows, which should be real. Row 0 of
    // zipped_0 holds re(row 0) as its real part and re(row N1 / 2) as its
    // imaginary part; rows n1 > 0 pass through unchanged.
    ComplexFunc zipped_0(prefix + "zipped_0");
    zipped_0(A({n0, n1}, args)) = select(n1 > 0, likely(unzipped(A({n0, n1}, args))), ComplexExpr(re(unzipped(A({n0, 0}, args))), re(unzipped(A({n0, N1 / 2}, args)))));

    // The vectorization of the columns must not exceed this value.
    int zipped_extent0 = std::min((N1 + 1) / 2, zip_width);

    // transpose so we can FFT dimension 0 (by making it dimension 1).
    ComplexFunc unzippedT, unzippedT_tiled;
    std::tie(unzippedT, unzippedT_tiled) = tiled_transpose(zipped_0, zipped_extent0, target, prefix);

    // DFT down the columns again (the rows of the original). The accumulated
    // gain (including the /2 from unzipping) is applied here.
    ComplexFunc dftT = fft_dim1(unzippedT, R0, -1,  // sign
                                zipped_extent0, gain, desc.parallel, prefix, target, &twiddle_cache);

    // Transpose the result back to the original orientation. This function
    // always transposes back; it has no option to return the transposed DFT.
    ComplexFunc dft = transpose(dftT);

    // We are going to add a row to the result (with update steps) by unzipping
    // the DC and Nyquist bin rows. To avoid unnecessarily computing some junk for
    // this row before we overwrite it, pad the pure definition with undef.
    // Dimension 0 is left unbounded; dimension 1 is bounded to [0, N1 / 2).
    dft = ComplexFunc(constant_exterior((Func)dft, Tuple(undef_z()), Expr(), Expr(), Expr(0), Expr(N1 / 2)));

    // Unzip the DFTs of the DC and Nyquist bin DFTs. Unzip the Nyquist DFT first,
    // because the DC bin DFT is updated in-place. For more information about
    // this, see the large comment above this function.
    // n0z1 covers [1, N0 / 2]; n0z2 covers [N0 / 2, N0 - 1].
    RDom n0z1(1, N0 / 2);
    RDom n0z2(N0 / 2, N0 / 2);

    // Update 0: Unzip the DC bin of the DFT of the Nyquist bin row.
    dft(A({0, N1 / 2}, args)) = im(dft(A({0, 0}, args)));
    // Update 1: Unzip the rest of the DFT of the Nyquist bin row.
    dft(A({n0z1, N1 / 2}, args)) = 0.5f * -j * (dft(A({n0z1, 0}, args)) - conj(dft(A({N0 - n0z1, 0}, args))));
    // Update 2: Compute the rest of the Nyquist bin row via conjugate symmetry.
    // Note that this redundantly computes n0 = N0/2, but that's faster and easier
    // than trying to deal with N0/2 - 1 bins.
    dft(A({n0z2, N1 / 2}, args)) = conj(dft(A({N0 - n0z2, N1 / 2}, args)));

    // Update 3: Unzip the DC bin of the DFT of the DC bin row.
    dft(A({0, 0}, args)) = re(dft(A({0, 0}, args)));
    // Update 4: Unzip the rest of the DFT of the DC bin row.
    dft(A({n0z1, 0}, args)) = 0.5f * (dft(A({n0z1, 0}, args)) + conj(dft(A({N0 - n0z1, 0}, args))));
    // Update 5: Compute the rest of the DC bin row via conjugate symmetry.
    // Note that this redundantly computes n0 = N0/2, but that's faster and easier
    // than trying to deal with N0/2 - 1 bins.
    dft(A({n0z2, 0}, args)) = conj(dft(A({N0 - n0z2, 0}, args)));

    // Schedule.
    dftT.compute_at(dft, outer);

    // Schedule the tiled transposes.
    if (unzippedT_tiled.defined()) {
        unzippedT_tiled.compute_at(dftT, group);
    }

    // Schedule the input, if requested.
    if (desc.schedule_input) {
        r.compute_at(dft1, group);
    }

    // Vectorize the zip groups, and unroll by a factor of 2 to simplify the
    // even/odd selection.
    Var n0o("n0o"), n0i("n0i");
    unzipped.compute_at(dft, outer)
        .split(n0, n0o, n0i, zip_width * 2)
        .reorder(n0i, n1, n0o)
        .vectorize(n0i, zip_width)
        .unroll(n0i);
    dft1.compute_at(unzipped, n0o);
    if (desc.parallel) {
        // Note that this also parallelizes dft1, which is computed inside this loop
        // of unzipped.
        unzipped.parallel(n0o);
    }

    // Schedule the final DFT transpose and unzipping updates.
    // NOTE(review): if N0 < natural_vector_size<float>() the unroll factor
    // below is 0 — confirm callers guarantee a larger N0.
    dft.vectorize(n0, target.natural_vector_size<float>())
        .unroll(n0, std::min(N0 / target.natural_vector_size<float>(), 4));

    // The Nyquist bin at n0z = N0/2 looks like a race condition because it
    // simplifies to an expression similar to the DC bin. However, we include it
    // in the reduction because it makes the reduction have length N/2, which is
    // convenient for vectorization, and just ignore the resulting appearance of
    // a race condition.
    dft.update(1).allow_race_conditions()
        .vectorize(n0z1, target.natural_vector_size<float>());
    dft.update(2).allow_race_conditions()
        .vectorize(n0z2, target.natural_vector_size<float>());
    dft.update(4).allow_race_conditions()
        .vectorize(n0z1, target.natural_vector_size<float>());
    dft.update(5).allow_race_conditions()
        .vectorize(n0z2, target.natural_vector_size<float>());

    // Our result is undefined outside these bounds.
    dft.bound(n0, 0, N0);
    dft.bound(n1, 0, (N1 + 1) / 2 + 1);

    return dft;
}