// Host-side circular shift of an array with up to four axes.
//
// Creates a new array with the same dimensions as `in`. The element written
// to output coordinate (ox, oy, oz, ow) is read from the input coordinate
// obtained by adding a per-axis offset to the output coordinate and wrapping
// the sum with simple_mod(). T is the element type and is declared outside
// this block.
//
// Parameters:
//   in    - source array; its dims() and strides() drive the traversal.
//   sdims - requested shift per axis; may be negative or larger in magnitude
//           than the axis length.
//
// Returns:
//   A newly created array holding the shifted data. If any axis of `in` has
//   length zero the (empty) output is returned without touching any element.
//
// NOTE(review): the innermost index is added without multiplying by
// ost[0] / ist[0], so a unit stride along axis 0 is assumed - confirm.
// NOTE(review): simple_mod() is defined elsewhere; this code relies on it
// wrapping values in [0, 2 * dim) back into [0, dim) - confirm.
Array<T> shift(const Array<T> &in, const af::dim4 &sdims)
{
    const af::dim4 iDims = in.dims();
    af::dim4 oDims = iDims;

    Array<T> out = createEmptyArray<T>(oDims);

    // A zero-length axis means there is nothing to copy, and the
    // `sdims[i] % oDims[i]` below would be an integer modulo by zero
    // (undefined behaviour), so hand back the empty result right away.
    for (int i = 0; i < 4; i++) {
        if (oDims[i] == 0) {
            return out;
        }
    }

    T* outPtr = out.get();
    const T* inPtr = in.get();

    const af::dim4 ist = in.strides();
    const af::dim4 ost = out.strides();

    dim_type sdims_[4];
    // Need to do this because we are mapping output to input in the kernel
    for (int i = 0; i < 4; i++) {
        // sdims_[i] is always non-negative and lies in [0, oDims[i]].
        // Negative shifts are converted to positive ones by going the other
        // way round; a positive shift s becomes oDims[i] - (s % oDims[i]).
        sdims_[i] = -(sdims[i] % oDims[i]) + oDims[i] * (sdims[i] > 0);
        assert(sdims_[i] >= 0 && sdims_[i] <= oDims[i]);
    }

    for (dim_type ow = 0; ow < oDims[3]; ow++) {
        // Offsets contributed by the outermost axis (output and input).
        const dim_type oW = ow * ost[3];
        const dim_type iw = simple_mod((ow + sdims_[3]), oDims[3]);
        const dim_type iW = iw * ist[3];

        for (dim_type oz = 0; oz < oDims[2]; oz++) {
            const dim_type oZW = oW + oz * ost[2];
            const dim_type iz  = simple_mod((oz + sdims_[2]), oDims[2]);
            const dim_type iZW = iW + iz * ist[2];

            for (dim_type oy = 0; oy < oDims[1]; oy++) {
                const dim_type oYZW = oZW + oy * ost[1];
                const dim_type iy   = simple_mod((oy + sdims_[1]), oDims[1]);
                const dim_type iYZW = iZW + iy * ist[1];

                for (dim_type ox = 0; ox < oDims[0]; ox++) {
                    const dim_type oIdx = oYZW + ox;
                    const dim_type ix   = simple_mod((ox + sdims_[0]), oDims[0]);
                    const dim_type iIdx = iYZW + ix;

                    outPtr[oIdx] = inPtr[iIdx];
                }
            }
        }
    }

    return out;
}
// Device kernel: copies `in` into `out` with a per-axis wrapped offset.
//
// For every output coordinate (ox, oy, oz, ow) handled by this thread, the
// source coordinate is simple_mod(o + d, out.dims[axis]) on each axis, where
// d0..d3 are the offsets for axes 0..3. T is the element type and is declared
// outside this block.
//
// Work decomposition, as read from the index arithmetic below:
//   - blockIdx.x carries both the axis-2 index (blockIdx.x / blocksPerMatX)
//     and the tile number along axis 0 (the remainder).
//   - blockIdx.y and blockIdx.z are flattened into one linear index that
//     carries the axis-3 index (linear / blocksPerMatY) and the tile number
//     along axis 1 (the remainder).
//   - Each thread starts at its own (x, y) position inside the tile and then
//     advances by blocksPerMatX * blockDim.x along axis 0 and by
//     blocksPerMatY * blockDim.y along axis 1 until it leaves the array.
//
// NOTE(review): d0..d3 are presumably pre-normalised to non-negative values
// by the host launcher so that simple_mod() can wrap them - confirm.
// NOTE(review): the innermost index is added without a stride multiply, so
// unit stride along axis 0 is assumed for both arrays - confirm.
__global__ void shift_kernel(Param<T> out, CParam<T> in,
                             const int d0, const int d1,
                             const int d2, const int d3,
                             const int blocksPerMatX, const int blocksPerMatY)
{
    // Flatten the y/z grid axes into a single linear block index.
    const unsigned int linearBlockY = blockIdx.y + blockIdx.z * gridDim.y;

    // Which axis-2 / axis-3 slice this block belongs to.
    const int oz = blockIdx.x / blocksPerMatX;
    const int ow = linearBlockY / blocksPerMatY;

    // Tile number of this block inside its 2-D slice.
    const int tileX = blockIdx.x - oz * blocksPerMatX;
    const int tileY = linearBlockY - ow * blocksPerMatY;

    // First (x, y) element owned by this thread.
    const int xBegin = tileX * blockDim.x + threadIdx.x;
    const int yBegin = tileY * blockDim.y + threadIdx.y;

    // Threads that start outside the output have nothing to do.
    const bool inBounds = xBegin < out.dims[0] && yBegin < out.dims[1] &&
                          oz < out.dims[2] && ow < out.dims[3];
    if (!inBounds) {
        return;
    }

    // Distance between consecutive elements visited by the same thread.
    const int stepX = blocksPerMatX * blockDim.x;
    const int stepY = blocksPerMatY * blockDim.y;

    // Wrapped source indices and base offsets for the two outer axes; these
    // are fixed for the lifetime of the thread.
    const int iz = simple_mod((oz + d2), out.dims[2]);
    const int iw = simple_mod((ow + d3), out.dims[3]);

    const int outBase = ow * out.strides[3] + oz * out.strides[2];
    const int inBase  = iw * in.strides[3] + iz * in.strides[2];

    for (int oy = yBegin; oy < out.dims[1]; oy += stepY) {
        const int iy = simple_mod((oy + d1), out.dims[1]);

        // Row offsets are constant across the inner loop.
        const int outRow = outBase + oy * out.strides[1];
        const int inRow  = inBase + iy * in.strides[1];

        for (int ox = xBegin; ox < out.dims[0]; ox += stepX) {
            const int ix = simple_mod((ox + d0), out.dims[0]);

            out.ptr[outRow + ox] = in.ptr[inRow + ix];
        }
    }
}