// Collects the group of blocks similar to the reference block located at (j, i)
// in the current frame, searching the current frame first and then tracking the
// matches through the preceding and the following frames.
//
// ref  - reference planes, indexed by frame number
// j, i - position of the reference block, forwarded to PosType(j, i)
//
// Returns the matches as (key, 3D position) pairs, truncated to para.GroupSize
// entries when para.GroupSize is positive.
VBM3D_Process_Base::Pos3PairCode VBM3D_Process_Base::BlockMatching(
    const std::vector<const FLType *> &ref, PCType j, PCType i) const
{
    // Nothing to search for when the group holds a single block or thMSE is not positive:
    // the group is then made of the reference block alone
    if (d.para.GroupSize == 1 || d.para.thMSE <= 0)
    {
        return Pos3PairCode(1, Pos3Pair(KeyType(0), Pos3Type(cur, j, i)));
    }

    Pos3PairCode matchCode;

    // Reference block, read from the reference plane of the current frame
    block_type refBlock(ref[cur], ref_stride[0], d.para.BlockSize, d.para.BlockSize, PosType(j, i));

    // Appends the matches found in one frame to matchCode, attaching the frame index to each position
    const auto appendMatches = [&matchCode](const PosPairCode &found, int frame)
    {
        matchCode.resize(matchCode.size() + found.size());
        auto out = matchCode.end() - found.size();

        for (const PosPair &match : found)
        {
            *out = Pos3Pair(match.first, Pos3Type(match.second, frame));
            ++out;
        }
    };

    // Stores the positions of the first min(PSnum, number of matches) matches into "positions"
    const auto takeFirstPositions = [&](const PosPairCode &found, PosCode &positions)
    {
        const PCType count = Min(d.para.PSnum, static_cast<PCType>(found.size()));
        positions.resize(count);
        std::transform(found.begin(), found.begin() + count, positions.begin(),
            [](const PosPair &match)
        {
            return match.second;
        });
    };

    // Full-range block matching inside the current frame
    const int curFrame = cur;
    PosPairCode curMatch = refBlock.BlockMatchingMulti(ref[curFrame],
        ref_height[0], ref_width[0], ref_stride[0], FLType(1),
        d.para.BMrange, d.para.BMstep, d.para.thMSE, 1, d.para.GroupSize, true);
    appendMatches(curMatch, curFrame);

    // Search positions derived from the current frame; they seed both temporal directions
    PosCode curPosCode;
    takeFirstPositions(curMatch, curPosCode);
    PosCode curSearchPos = refBlock.GenSearchPos(curPosCode,
        ref_height[0], ref_width[0], d.para.PSrange, d.para.PSstep);

    // Predictive-search block matching along one temporal direction:
    // step == -1 walks the backward frames down to frame 0,
    // step == +1 walks the forward frames up to frame (frames - 1)
    const auto trackFrames = [&](const int step)
    {
        const int first = cur + step;
        PosPairCode found;
        PosCode prevPositions;

        for (int f = first; step < 0 ? f >= 0 : f < frames; f += step)
        {
            if (f != first)
            {
                // Later frames are searched around the matches of the previously visited frame
                takeFirstPositions(found, prevPositions);
                PosCode searchPos = refBlock.GenSearchPos(prevPositions,
                    ref_height[0], ref_width[0], d.para.PSrange, d.para.PSstep);
                found = refBlock.BlockMatchingMulti(ref[f], ref_stride[0], FLType(1),
                    searchPos, d.para.thMSE, d.para.GroupSize, true);
            }
            else
            {
                // The frame adjacent to the current one is searched around the current frame's matches
                found = refBlock.BlockMatchingMulti(ref[f], ref_stride[0], FLType(1),
                    curSearchPos, d.para.thMSE, d.para.GroupSize, true);
            }

            appendMatches(found, f);
        }
    };

    trackFrames(-1); // backward frames
    trackFrames(+1); // forward frames

    // Keep at most GroupSize matches.
    // The leading entry is excluded from the partial sort and therefore stays in first place
    // (presumably the reference block itself - confirm against BlockMatchingMulti).
    const bool tooManyMatches = d.para.GroupSize > 0
        && static_cast<PCType>(matchCode.size()) > d.para.GroupSize;

    if (tooManyMatches)
    {
        std::partial_sort(matchCode.begin() + 1,
            matchCode.begin() + d.para.GroupSize, matchCode.end());
        matchCode.resize(d.para.GroupSize);
    }

    return matchCode;
}
void BM3D_Basic_Process::CollaborativeFilter(int plane, FLType *ResNum, FLType *ResDen, const FLType *src, const FLType *ref, const PosPairCode &code) const { PCType GroupSize = static_cast<PCType>(code.size()); // When para.GroupSize > 0, limit GroupSize up to para.GroupSize if (d.para.GroupSize > 0 && GroupSize > d.para.GroupSize) { GroupSize = d.para.GroupSize; } // Construct source group guided by matched pos code block_group srcGroup(src, src_stride[plane], code, GroupSize, d.para.BlockSize, d.para.BlockSize); // Initialize retianed coefficients of hard threshold filtering int retainedCoefs = 0; // Apply forward 3D transform to the source group d.f[plane].fp[GroupSize - 1].execute_r2r(srcGroup.data(), srcGroup.data()); // Apply hard-thresholding to the source group auto srcp = srcGroup.data(); auto thrp = d.f[plane].thrTable[GroupSize - 1].get(); const auto upper = srcp + srcGroup.size(); #if defined(__SSE2__) static const ptrdiff_t simd_step = 4; const ptrdiff_t simd_residue = srcGroup.size() % simd_step; const ptrdiff_t simd_width = srcGroup.size() - simd_residue; static const __m128 zero_ps = _mm_setzero_ps(); __m128i cmp_sum = _mm_setzero_si128(); for (const auto upper1 = srcp + simd_width; srcp < upper1; srcp += simd_step, thrp += simd_step) { const __m128 s1 = _mm_load_ps(srcp); const __m128 t1p = _mm_load_ps(thrp); const __m128 t1n = _mm_sub_ps(zero_ps, t1p); const __m128 cmp1 = _mm_cmpgt_ps(s1, t1p); const __m128 cmp2 = _mm_cmplt_ps(s1, t1n); const __m128 cmp = _mm_or_ps(cmp1, cmp2); const __m128 d1 = _mm_and_ps(cmp, s1); _mm_store_ps(srcp, d1); cmp_sum = _mm_sub_epi32(cmp_sum, _mm_castps_si128(cmp)); } alignas(16) int32_t cmp_sum_i32[4]; _mm_store_si128(reinterpret_cast<__m128i *>(cmp_sum_i32), cmp_sum); retainedCoefs += cmp_sum_i32[0] + cmp_sum_i32[1] + cmp_sum_i32[2] + cmp_sum_i32[3]; #endif for (; srcp < upper; ++srcp, ++thrp) { if (*srcp > *thrp || *srcp < -*thrp) { ++retainedCoefs; } else { *srcp = 0; } } // Apply backward 3D transform to the 
filtered group d.f[plane].bp[GroupSize - 1].execute_r2r(srcGroup.data(), srcGroup.data()); // Calculate weight for the filtered group // Also include the normalization factor to compensate for the amplification introduced in 3D transform FLType denWeight = retainedCoefs < 1 ? 1 : FLType(1) / static_cast<FLType>(retainedCoefs); FLType numWeight = static_cast<FLType>(denWeight / d.f[plane].finalAMP[GroupSize - 1]); // Store the weighted filtered group to the numerator part of the basic estimation // Store the weight to the denominator part of the basic estimation srcGroup.AddTo(ResNum, dst_stride[plane], numWeight); srcGroup.CountTo(ResDen, dst_stride[plane], denWeight); }
void BM3D_Final_Process::CollaborativeFilter(int plane, FLType *ResNum, FLType *ResDen, const FLType *src, const FLType *ref, const PosPairCode &code) const { PCType GroupSize = static_cast<PCType>(code.size()); // When para.GroupSize > 0, limit GroupSize up to para.GroupSize if (d.para.GroupSize > 0 && GroupSize > d.para.GroupSize) { GroupSize = d.para.GroupSize; } // Construct source group and reference group guided by matched pos code block_group srcGroup(src, src_stride[plane], code, GroupSize, d.para.BlockSize, d.para.BlockSize); block_group refGroup(ref, ref_stride[plane], code, GroupSize, d.para.BlockSize, d.para.BlockSize); // Initialize L2-norm of Wiener coefficients FLType L2Wiener = 0; // Apply forward 3D transform to the source group and the reference group d.f[plane].fp[GroupSize - 1].execute_r2r(srcGroup.data(), srcGroup.data()); d.f[plane].fp[GroupSize - 1].execute_r2r(refGroup.data(), refGroup.data()); // Apply empirical Wiener filtering to the source group guided by the reference group const FLType sigmaSquare = d.f[plane].wienerSigmaSqr[GroupSize - 1]; auto srcp = srcGroup.data(); auto refp = refGroup.data(); const auto upper = srcp + srcGroup.size(); #if defined(__SSE2__) static const ptrdiff_t simd_step = 4; const ptrdiff_t simd_residue = srcGroup.size() % simd_step; const ptrdiff_t simd_width = srcGroup.size() - simd_residue; const __m128 sgm_sqr = _mm_set_ps1(sigmaSquare); __m128 l2wiener_sum = _mm_setzero_ps(); for (const auto upper1 = srcp + simd_width; srcp < upper1; srcp += simd_step, refp += simd_step) { const __m128 s1 = _mm_load_ps(srcp); const __m128 r1 = _mm_load_ps(refp); const __m128 r1sqr = _mm_mul_ps(r1, r1); const __m128 wiener = _mm_mul_ps(r1sqr, _mm_rcp_ps(_mm_add_ps(r1sqr, sgm_sqr))); const __m128 d1 = _mm_mul_ps(s1, wiener); _mm_store_ps(srcp, d1); l2wiener_sum = _mm_add_ps(l2wiener_sum, _mm_mul_ps(wiener, wiener)); } alignas(16) FLType l2wiener_sum_f32[4]; _mm_store_ps(l2wiener_sum_f32, l2wiener_sum); L2Wiener += 
l2wiener_sum_f32[0] + l2wiener_sum_f32[1] + l2wiener_sum_f32[2] + l2wiener_sum_f32[3]; #endif for (; srcp < upper; ++srcp, ++refp) { const FLType refSquare = *refp * *refp; const FLType wienerCoef = refSquare / (refSquare + sigmaSquare); *srcp *= wienerCoef; L2Wiener += wienerCoef * wienerCoef; } // Apply backward 3D transform to the filtered group d.f[plane].bp[GroupSize - 1].execute_r2r(srcGroup.data(), srcGroup.data()); // Calculate weight for the filtered group // Also include the normalization factor to compensate for the amplification introduced in 3D transform FLType denWeight = L2Wiener <= 0 ? 1 : FLType(1) / L2Wiener; FLType numWeight = static_cast<FLType>(denWeight / d.f[plane].finalAMP[GroupSize - 1]); // Store the weighted filtered group to the numerator part of the final estimation // Store the weight to the denominator part of the final estimation srcGroup.AddTo(ResNum, dst_stride[plane], numWeight); srcGroup.CountTo(ResDen, dst_stride[plane], denWeight); }