void fwt(unsigned int N, unsigned int flags, const long shifts[N], const long dims[N], complex float* out, const long istr[N], const complex float* in, const long minsize[N], long flen, const float filter[2][2][flen]) { if (0 == flags) { if (out != in) md_copy2(N, dims, istr, out, istr, in, CFL_SIZE); return; } unsigned long coeffs = wavelet_coeffs(N, flags, dims, minsize, flen); long wdims[2 * N]; wavelet_dims(N, flags, wdims, dims, flen); long ostr[2 * N]; md_calc_strides(2 * N, ostr, wdims, CFL_SIZE); long offset = coeffs - md_calc_size(2 * N, wdims); debug_printf(DP_DEBUG4, "%d %ld %ld\n", flags, coeffs, offset); long shifts0[N]; for (unsigned int i = 0; i < N; i++) shifts0[i] = 0; fwtN(N, flags, shifts, dims, ostr, out + offset, istr, in, flen, filter); fwt(N, wavelet_filter_flags(N, flags, wdims, minsize), shifts0, wdims, out, ostr, out + offset, minsize, flen, filter); }
void fwt2(unsigned int N, unsigned int flags, const long shifts[N], const long odims[N], const long ostr[N], complex float* out, const long idims[N], const long istr[N], const complex float* in, const long minsize[N], long flen, const float filter[2][2][flen]) { assert(wavelet_check_dims(N, flags, idims, minsize)); if (0 == flags) { // note: recursion does *not* end here assert(md_check_compat(N, 0u, odims, idims)); md_copy2(N, idims, ostr, out, istr, in, CFL_SIZE); return; } // check output dimensions long odims2[N]; wavelet_coeffs2(N, flags, odims2, idims, minsize, flen); assert(md_check_compat(N, 0u, odims2, odims)); long wdims2[2 * N]; wavelet_dims(N, flags, wdims2, idims, flen); // only consider transform dims... long dims1[N]; md_select_dims(N, flags, dims1, idims); long wdims[2 * N]; wavelet_dims(N, flags, wdims, dims1, flen); long level_coeffs = md_calc_size(2 * N, wdims); // ... which get embedded in dimension b unsigned int b = ffs(flags) - 1; long ostr2[2 * N]; md_calc_strides(2 * N, ostr2, wdims, ostr[b]); // merge with original strides for (unsigned int i = 0; i < N; i++) if (!MD_IS_SET(flags, i)) ostr2[i] = ostr[i]; assert(odims[b] >= level_coeffs); long offset = (odims[b] - level_coeffs) * (ostr[b] / CFL_SIZE); long bands = md_calc_size(N, wdims + N); long coeffs = md_calc_size(N, wdims + 0); debug_printf(DP_DEBUG4, "fwt2: flags:%d lcoeffs:%ld coeffs:%ld (space:%ld) bands:%ld str:%ld off:%ld\n", flags, level_coeffs, coeffs, odims2[b], bands, ostr[b], offset / istr[b]); // subtract coefficients in high band odims2[b] -= (bands - 1) * coeffs; assert(odims2[b] > 0); long shifts0[N]; for (unsigned int i = 0; i < N; i++) shifts0[i] = 0; unsigned int flags2 = wavelet_filter_flags(N, flags, wdims, minsize); assert((0 == offset) == (0u == flags2)); fwtN(N, flags, shifts, idims, ostr2, out + offset, istr, in, flen, filter); if (0 != flags2) { long odims3[N]; wavelet_coeffs2(N, flags2, odims3, wdims2, minsize, flen); long ostr3[N]; embed(N, flags, ostr3, odims3, ostr); fwt2(N, flags2, shifts0, odims3, ostr3, out, wdims2, ostr2, out + offset, minsize, flen, filter); } }