// Convolves `im` with `filter` by multiplying their Fourier transforms.
//
// im     - the image to convolve.
// filter - the kernel; width, height and frames must all be odd (except in
//          the Homogeneous case, where the check happens in the recursive
//          calls).
// b      - how the image is extended beyond its edges:
//            Homogeneous: computed as (im * filter) / (ones * filter), both
//                         using the Zero boundary condition.
//            Clamp:       the padding replicates the nearest edge sample.
//            Zero:        the padding is left untouched (see the zero-fill
//                         note below).
//            Wrap:        no padding is added, so the circular convolution
//                         performed by the FFT is used directly.
// m      - how image and filter channels are combined (Inner, Outer or
//          Elementwise); this also determines the number of output channels.
//
// Returns a new image with the same width, height and frames as `im`.
//
// Complex buffers store each channel as an interleaved (real, imaginary)
// pair, hence the "channels*2" sizes below.
//
// NOTE(review): this code never writes the padding, the imaginary parts, or
// the initial value of the accumulators, so it assumes a freshly constructed
// Image is zero-filled - confirm against the Image constructor.
Image FFTConvolve::apply(Window im, Window filter, Convolve::BoundaryCondition b, Multiply::Mode m) {
    int resultChannels = 0;

    // Check the channel counts are compatible with the multiply mode, and
    // work out how many channels the result will have.
    if (m == Multiply::Inner) {
        assert(im.channels % filter.channels == 0 || filter.channels % im.channels == 0,
               "For inner-product convolution either the image must have a number of"
               " channels that is a multiple of the number of channels in the filter,"
               " or vice-versa.\n");
        resultChannels = max(im.channels / filter.channels, filter.channels / im.channels);
    } else if (m == Multiply::Outer) {
        // anything goes
        resultChannels = im.channels * filter.channels;
    } else if (m == Multiply::Elementwise) {
        assert(im.channels == filter.channels,
               "For elementwise convolution the filter must have the same number of channels as the image\n");
        resultChannels = im.channels;
    } else {
        panic("Unknown channel mode: %d\n", m);
    }

    // Deal with the homogeneous case recursively. This is slightly
    // inefficient because we construct and transform the filter
    // twice, but it makes the code much simpler.
    if (b == Convolve::Homogeneous) {
        Image result = apply(im, filter, Convolve::Zero, m);
        Image weight(im.width, im.height, im.frames, im.channels);
        Offset::apply(weight, 1.0f);
        Image resultW = apply(weight, filter, Convolve::Zero, m);
        Divide::apply(result, resultW);
        return result;
    }

    assert(filter.width % 2 == 1 &&
           filter.height % 2 == 1 &&
           filter.frames % 2 == 1,
           "The filter must have odd dimensions\n");

    // Pad by the filter radius on every side so the circular convolution
    // does not bleed across opposite edges. Wrap mode wants exactly that
    // bleed, so it uses no padding at all.
    int xPad = filter.width/2;
    int yPad = filter.height/2;
    int tPad = filter.frames/2;

    if (b == Convolve::Wrap) {
        xPad = yPad = tPad = 0;
    }

    // 1) Make the padded complex image
    Image imT(im.width+xPad*2, im.height+yPad*2, im.frames+tPad*2, im.channels*2);

    if (b == Convolve::Clamp) {
        // Every padded sample reads from the nearest in-bounds source sample.
        for (int t = 0; t < imT.frames; t++) {
            int st = clamp(t-tPad, 0, im.frames-1);
            for (int y = 0; y < imT.height; y++) {
                int sy = clamp(y-yPad, 0, im.height-1);
                float *imTPtr = imT(0, y, t);
                float *imPtr = im(0, sy, st);
                for (int x = 0; x < imT.width; x++) {
                    int sx = clamp(x-xPad, 0, im.width-1);
                    for (int c = 0; c < im.channels; c++) {
                        *imTPtr++ = imPtr[sx*im.channels+c];
                        *imTPtr++ = 0;
                    }
                }
            }
        }
    } else { // Zero or Wrap
        // Copy the image into the interior; only the real parts are written.
        for (int t = 0; t < im.frames; t++) {
            for (int y = 0; y < im.height; y++) {
                float *imPtr = im(0, y, t);
                float *imTPtr = imT(xPad, y+yPad, t+tPad);
                for (int x = 0; x < im.width; x++) {
                    for (int c = 0; c < im.channels; c++) {
                        *imTPtr++ = *imPtr++;
                        imTPtr++; // skip the imaginary part
                    }
                }
            }
        }
    }

    // 2) Transform the padded image
    FFT::apply(imT);

    // 3) Make a padded complex filter of the same size, with the filter's
    // centre tap at the origin and the negative offsets wrapped around to
    // the far end of each axis.
    //
    // The wrap is a true modulo and the taps are accumulated rather than
    // assigned. In Wrap mode the buffer is only as large as the image, so a
    // filter larger than the image has several taps landing on the same
    // cell; summing them is what circular convolution requires, and it
    // keeps every index in range. In the padded modes each cell receives at
    // most one tap, so accumulation is equivalent to assignment.
    Image filterT(imT.width, imT.height, imT.frames, filter.channels*2);
    for (int t = 0; t < filter.frames; t++) {
        int ft = (t - filter.frames/2) % filterT.frames;
        if (ft < 0) ft += filterT.frames;
        for (int y = 0; y < filter.height; y++) {
            int fy = (y - filter.height/2) % filterT.height;
            if (fy < 0) fy += filterT.height;
            for (int x = 0; x < filter.width; x++) {
                int fx = (x - filter.width/2) % filterT.width;
                if (fx < 0) fx += filterT.width;
                float *filterPtr = filter(x, y, t);
                float *filterTPtr = filterT(fx, fy, ft);
                for (int c = 0; c < filter.channels; c++) {
                    filterTPtr[2*c] += filterPtr[c];
                }
            }
        }
    }

    // 4) Transform the padded filter
    FFT::apply(filterT);

    // 5) Multiply the two into a padded complex transformed result
    Image resultT(imT.width, imT.height, imT.frames, resultChannels*2);
    for (int t = 0; t < resultT.frames; t++) {
        for (int y = 0; y < resultT.height; y++) {
            float *resultTPtr = resultT(0, y, t);
            float *filterTPtr = filterT(0, y, t);
            float *imTPtr = imT(0, y, t);
            if (m == Multiply::Outer) {
                // One result channel per (filter channel, image channel)
                // pair, with the image channel varying fastest.
                for (int x = 0; x < resultT.width; x++) {
                    for (int cf = 0; cf < filterT.channels; cf+=2) {
                        for (int ci = 0; ci < imT.channels; ci+=2) {
                            *resultTPtr++ = filterTPtr[cf]*imTPtr[ci] - filterTPtr[cf+1]*imTPtr[ci+1];
                            *resultTPtr++ = filterTPtr[cf+1]*imTPtr[ci] + filterTPtr[cf]*imTPtr[ci+1];
                        }
                    }
                    imTPtr += imT.channels;
                    filterTPtr += filterT.channels;
                }
            } else if (m == Multiply::Inner && filter.channels > im.channels) {
                // The filter has more channels: each result channel is the
                // dot product of all the image channels with the next
                // im.channels filter channels.
                for (int x = 0; x < resultT.width; x++) {
                    for (int cr = 0; cr < resultChannels; cr++) {
                        for (int ci = 0; ci < imT.channels; ci+=2) {
                            resultTPtr[0] += filterTPtr[0]*imTPtr[ci] - filterTPtr[1]*imTPtr[ci+1];
                            resultTPtr[1] += filterTPtr[1]*imTPtr[ci] + filterTPtr[0]*imTPtr[ci+1];
                            filterTPtr += 2;
                        }
                        resultTPtr += 2;
                    }
                    imTPtr += imT.channels;
                }
            } else if (m == Multiply::Inner) {
                // The image has at least as many channels: each result
                // channel is the dot product of all the filter channels with
                // the next filter.channels image channels.
                for (int x = 0; x < resultT.width; x++) {
                    for (int cr = 0; cr < resultChannels; cr++) {
                        for (int cf = 0; cf < filterT.channels; cf+=2) {
                            resultTPtr[0] += filterTPtr[cf]*imTPtr[0] - filterTPtr[cf+1]*imTPtr[1];
                            resultTPtr[1] += filterTPtr[cf+1]*imTPtr[0] + filterTPtr[cf]*imTPtr[1];
                            imTPtr += 2;
                        }
                        resultTPtr += 2;
                    }
                    filterTPtr += filterT.channels;
                }
            } else { // m == ELEMENTWISE
                // Channel c of the image times channel c of the filter.
                for (int x = 0; x < resultT.width; x++) {
                    for (int c = 0; c < resultChannels; c++) {
                        resultTPtr[0] += filterTPtr[0]*imTPtr[0] - filterTPtr[1]*imTPtr[1];
                        resultTPtr[1] += filterTPtr[1]*imTPtr[0] + filterTPtr[0]*imTPtr[1];
                        imTPtr += 2;
                        resultTPtr += 2;
                        filterTPtr += 2;
                    }
                }
            }
        }
    }

    // 6) Inverse transform the result
    IFFT::apply(resultT);

    // 7) Remove the padding, and convert back to real numbers
    Image result(im.width, im.height, im.frames, resultChannels);
    for (int t = 0; t < im.frames; t++) {
        for (int y = 0; y < im.height; y++) {
            float *resultPtr = result(0, y, t);
            float *resultTPtr = resultT(xPad, y+yPad, t+tPad);
            for (int x = 0; x < im.width; x++) {
                for (int c = 0; c < resultChannels; c++) {
                    *resultPtr++ = *resultTPtr++;
                    // skip the imaginary part
                    resultTPtr++;
                }
            }
        }
    }

    return result;
}
void FFTConvolve::convolveSingle(Image im, Image filter, Image out, Convolve::BoundaryCondition b) { // Deal with the homogeneous case recursively. This is slightly // inefficient because we construct and transform the filter // twice, but it makes the code much simpler if (b == Convolve::Homogeneous) { Image result = apply(im, filter, Convolve::Zero, Multiply::Outer); Image weight(im.width, im.height, im.frames, 1); weight.set(1.0f); Image resultW = apply(weight, filter, Convolve::Zero, Multiply::Outer); out += Stats(filter).sum() * result / resultW; return; } assert(filter.width % 2 == 1 && filter.height % 2 == 1 && filter.frames % 2 == 1, "The filter must have odd dimensions\n"); int xPad = filter.width/2; int yPad = filter.height/2; int tPad = filter.frames/2; if (b == Convolve::Wrap) { xPad = yPad = tPad = 0; } Image weightT; Image imT = Image(im.width+xPad*2, im.height+yPad*2, im.frames+tPad*2, 2); //printf("1\n"); fflush(stdout); // 1) Make the padded complex image if (b == Convolve::Clamp) { for (int t = 0; t < imT.frames; t++) { int st = clamp(t-tPad, 0, im.frames-1); for (int y = 0; y < imT.height; y++) { int sy = clamp(y-yPad, 0, im.height-1); for (int x = 0; x < imT.width; x++) { int sx = clamp(x-xPad, 0, im.width-1); imT(x, y, t, 0) = im(sx, sy, st, 0); } } } } else { // Zero or Wrap imT.region(xPad, yPad, tPad, 0, im.width, im.height, im.frames, 1).set(im); } //printf("2\n"); fflush(stdout); // 2) Transform the padded image FFT::apply(imT); //printf("3\n"); fflush(stdout); // 3) Make a padded complex filter of the same size Image filterT(imT.width, imT.height, imT.frames, 2); for (int t = 0; t < filter.frames; t++) { int ft = t - filter.frames/2; if (ft < 0) ft += filterT.frames; for (int y = 0; y < filter.height; y++) { int fy = y - filter.height/2; if (fy < 0) fy += filterT.height; for (int x = 0; x < filter.width; x++) { int fx = x - filter.width/2; if (fx < 0) fx += filterT.width; filterT(fx, fy, ft, 0) = filter(x, y, t, 0); } } } //printf("4\n"); 
fflush(stdout); // 4) Transform the padded filter FFT::apply(filterT); //printf("5\n"); fflush(stdout); // 5) Multiply the two into a padded complex transformed result ComplexMultiply::apply(imT, filterT); //printf("6\n"); fflush(stdout); // 6) Inverse transorm the result IFFT::apply(imT); //printf("7\n"); fflush(stdout); // 7) Remove the padding, and convert back to real numbers out += imT.region(xPad, yPad, tPad, 0, im.width, im.height, im.frames, 1); }