示例#1
0
    // Circular shift (CPU path).
    //
    // Returns a freshly allocated array with the same dims as `in` in which,
    // along every dimension i, output index o holds the input element at
    // (o - sdims[i]) wrapped into [0, dims[i]). Shifts may be negative or
    // larger than the extent; they are reduced modulo the extent first.
    // Zero-extent dimensions are accepted and yield an empty result instead
    // of dividing by zero.
    //
    // NOTE(review): assumes simple_mod(x, n) == x % n for non-negative x, and
    // that dim 0 is unit-stride in `in` (ist[0] is never read) — confirm.
    Array<T> shift(const Array<T> &in, const af::dim4 &sdims)
    {
        const af::dim4 iDims = in.dims();
        af::dim4 oDims = iDims;

        Array<T> out = createEmptyArray<T>(oDims);

        T* outPtr = out.get();
        const T* inPtr = in.get();

        const af::dim4 ist = in.strides();
        const af::dim4 ost = out.strides();

        // The loops below map each output position back to its input
        // position, so every shift is converted into a non-negative offset
        // in [0, oDims[i]] that is *added* to the output index.
        dim_type sdims_[4];
        for(int i = 0; i < 4; i++) {
            const dim_type len = oDims[i];
            if(len == 0) {
                // Empty dimension: nothing will be copied, and evaluating
                // `sdims[i] % 0` would be undefined behaviour (typically a
                // SIGFPE crash). Any offset is fine since no loop over this
                // dimension executes.
                sdims_[i] = 0;
                continue;
            }
            // C++ '%' truncates toward zero, so rem has the sign of sdims[i].
            const dim_type rem = sdims[i] % len;
            // s >  0 -> len - rem in [1, len]    (congruent to -s mod len)
            // s <= 0 -> -rem      in [0, len-1]  (congruent to -s mod len)
            sdims_[i] = (sdims[i] > 0) ? (len - rem) : -rem;
            assert(sdims_[i] >= 0 && sdims_[i] <= len);
        }

        // Walk the output in storage order (dim 3 outermost, dim 0
        // innermost), computing the wrapped source coordinate per dimension.
        for(dim_type ow = 0; ow < oDims[3]; ow++) {
            const dim_type oW = ow * ost[3];
            const dim_type iw = simple_mod((ow + sdims_[3]), oDims[3]);
            const dim_type iW = iw * ist[3];
            for(dim_type oz = 0; oz < oDims[2]; oz++) {
                const dim_type oZW = oW + oz * ost[2];
                const dim_type iz = simple_mod((oz + sdims_[2]), oDims[2]);
                const dim_type iZW = iW + iz * ist[2];
                for(dim_type oy = 0; oy < oDims[1]; oy++) {
                    const dim_type oYZW = oZW + oy * ost[1];
                    const dim_type iy = simple_mod((oy + sdims_[1]), oDims[1]);
                    const dim_type iYZW = iZW + iy * ist[1];
                    for(dim_type ox = 0; ox < oDims[0]; ox++) {
                        const dim_type oIdx = oYZW + ox;
                        const dim_type ix = simple_mod((ox + sdims_[0]), oDims[0]);
                        const dim_type iIdx = iYZW + ix;

                        outPtr[oIdx] = inPtr[iIdx];
                    }
                }
            }
        }

        return out;
    }
示例#2
0
        // Circular-shift kernel.
        //
        // Each thread copies out[ox, oy, oz, ow] from
        //   in[simple_mod(ox + d0, n0), simple_mod(oy + d1, n1),
        //      simple_mod(oz + d2, n2), simple_mod(ow + d3, n3)]
        // with n = out.dims, i.e. d0..d3 are offsets *added* to the output
        // index. NOTE(review): presumably the host converts user shifts into
        // non-negative offsets before launch — confirm against the caller.
        //
        // Expected launch layout (derived from the index math below):
        //   blockIdx.x                          = oz * blocksPerMatX + x-tile
        //   blockIdx.y + blockIdx.z * gridDim.y = ow * blocksPerMatY + y-tile
        // Inside one (oz, ow) slice each thread walks dims 1 and 0 in steps of
        // blocksPerMatY*blockDim.y and blocksPerMatX*blockDim.x. No shared
        // memory is used. Dim 0 is assumed unit-stride in both arrays
        // (strides[0] is never read).
        __global__
        void shift_kernel(Param<T> out, CParam<T> in, const int d0, const int d1,
                            const int d2, const int d3,
                            const int blocksPerMatX, const int blocksPerMatY)
        {
            // Linearised y/z grid coordinate; kept unsigned like the
            // built-ins it is computed from.
            const unsigned gridRow = blockIdx.y + blockIdx.z * gridDim.y;

            // Which (dim 2, dim 3) slice this block belongs to...
            const int oz = blockIdx.x / blocksPerMatX;
            const int ow = gridRow / blocksPerMatY;

            // ...and which tile inside that slice.
            const int tileX = blockIdx.x - oz * blocksPerMatX;
            const int tileY = gridRow - ow * blocksPerMatY;

            // First output coordinate this thread handles in dims 0 and 1.
            const int startX = tileX * blockDim.x + threadIdx.x;
            const int startY = tileY * blockDim.y + threadIdx.y;

            const bool outOfRange = startX >= out.dims[0] || startY >= out.dims[1] ||
                                    oz >= out.dims[2] || ow >= out.dims[3];
            if(outOfRange) return;

            // Distance between consecutive elements visited by one thread.
            const int stepX = blockDim.x * blocksPerMatX;
            const int stepY = blockDim.y * blocksPerMatY;

            // Wrapped source slice for this block's (oz, ow).
            const int srcW = simple_mod((ow + d3), out.dims[3]);
            const int srcZ = simple_mod((oz + d2), out.dims[2]);

            const int dstSlice = ow * out.strides[3] + oz * out.strides[2];
            const int srcSlice = srcW * in.strides[3] + srcZ * in.strides[2];

            for(int y = startY; y < out.dims[1]; y += stepY) {
                const int srcY = simple_mod((y + d1), out.dims[1]);
                for(int x = startX; x < out.dims[0]; x += stepX) {
                    const int srcX = simple_mod((x + d0), out.dims[0]);

                    const int dst = dstSlice + y * out.strides[1] + x;
                    const int src = srcSlice + srcY * in.strides[1] + srcX;

                    out.ptr[dst] = in.ptr[src];
                }
            }
        }