示例#1
0
// Theano op code
// Authors: Arjun Jain, Frederic Bastien, Jan Schluter
// Reference code: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
//   and https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
// Adaptation for 3d
/**
 * 3D correlation / gradients via "im3d2col + GEMM" (single precision GEMM).
 *
 * All three arrays must be 5D and C-contiguous:
 *   bottom: (batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth)
 *   weight: (nFilters,  nChannels, kH, kW, kD)
 *   top:    (batchSize, nFilters,  topHeight, topWidth, topDepth)
 *
 * direction selects which array is written:
 *   0 - forward pass:            top    is computed from bottom and weight
 *   1 - gradient wrt. weights:   weight is computed from bottom and top
 *   2 - gradient wrt. inputs:    bottom is computed from weight and top
 *
 * dH/dW/dD are the strides, dilH/dilW/dilD the filter dilations and
 * padH/padW/padD the paddings along height/width/depth.
 *
 * Returns the array that was written (an alias of top, weight or bottom; its
 * reference count is NOT changed), or NULL on failure. A Python exception is
 * set by this function for every failure it detects itself.
 */
PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
                           PyGpuArrayObject *const weight,
                           PyGpuArrayObject *const top,
                           const size_t direction,
                           const size_t dH = 1,
                           const size_t dW = 1,
                           const size_t dD = 1,
                           const size_t dilH = 1,
                           const size_t dilW = 1,
                           const size_t dilD = 1,
                           const size_t padH = 0,
                           const size_t padW = 0,
                           const size_t padD = 0)
{
    // Reject unknown directions up front; otherwise no branch below would
    // assign the returned array.
    if (direction > 2)
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM: invalid direction %zu (expected 0, 1 or 2)",
                direction);
        return NULL;
    }
    // The strides are used as divisors when computing the output extents.
    if (dH == 0 || dW == 0 || dD == 0)
    {
        PyErr_SetString(PyExc_ValueError,
                "GpuCorr3dMM requires non-zero strides (dH, dW, dD)");
        return NULL;
    }

    if (PyGpuArray_NDIM(bottom) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires bottom of 5D");
        return NULL;
    }
    if (!GpuArray_IS_C_CONTIGUOUS(&bottom->ga))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires bottom to be C-contiguous, "
                "but strides are: %zd %zd %zd %zd %zd\n",
                PyGpuArray_STRIDES(bottom)[0],
                PyGpuArray_STRIDES(bottom)[1],
                PyGpuArray_STRIDES(bottom)[2],
                PyGpuArray_STRIDES(bottom)[3],
                PyGpuArray_STRIDES(bottom)[4]);
        return NULL;
    }

    if (PyGpuArray_NDIM(weight) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires weight of 5D");
        return NULL;
    }
    if (!GpuArray_IS_C_CONTIGUOUS(&weight->ga))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires weight to be C-contiguous, "
                "but strides are: %zd %zd %zd %zd %zd\n",
                PyGpuArray_STRIDES(weight)[0],
                PyGpuArray_STRIDES(weight)[1],
                PyGpuArray_STRIDES(weight)[2],
                PyGpuArray_STRIDES(weight)[3],
                PyGpuArray_STRIDES(weight)[4]);
        return NULL;
    }

    if (PyGpuArray_NDIM(top) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires top of 5D");
        return NULL;
    }
    if (!GpuArray_IS_C_CONTIGUOUS(&top->ga))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires top to be C-contiguous, "
                "but strides are: %zd %zd %zd %zd %zd\n",
                PyGpuArray_STRIDES(top)[0],
                PyGpuArray_STRIDES(top)[1],
                PyGpuArray_STRIDES(top)[2],
                PyGpuArray_STRIDES(top)[3],
                PyGpuArray_STRIDES(top)[4]);
        return NULL;
    }

    // Extract some shape information for later and check shape consistency
    // bottom: (batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth)
    const size_t batchSize = PyGpuArray_DIMS(bottom)[0];
    const size_t nChannels = PyGpuArray_DIMS(bottom)[1];
    const size_t bottomHeight = PyGpuArray_DIMS(bottom)[2];
    const size_t bottomWidth = PyGpuArray_DIMS(bottom)[3];
    const size_t bottomDepth = PyGpuArray_DIMS(bottom)[4];
    // weights: (nFilters, nChannels, rows, columns, slices)
    const size_t nFilters = PyGpuArray_DIMS(weight)[0];
    const size_t kH = PyGpuArray_DIMS(weight)[2];
    const size_t kW = PyGpuArray_DIMS(weight)[3];
    const size_t kD = PyGpuArray_DIMS(weight)[4];
    if (nChannels != PyGpuArray_DIMS(weight)[1]) {
        PyErr_SetString(PyExc_ValueError,
                "GpuCorr3dMM images and kernel must have the same stack size\n");
        return NULL;
    }
    // implicit dilated filter
    const size_t dil_kH = (kH - 1) * dilH + 1;
    const size_t dil_kW = (kW - 1) * dilW + 1;
    const size_t dil_kD = (kD - 1) * dilD + 1;
    // top: (batchSize, nFilters, topHeight, topWidth, topDepth)
    // NOTE(review): this is unsigned arithmetic; when the dilated kernel is
    // larger than the padded input the subtraction wraps around, and the
    // resulting extent normally fails the consistency check below.
    const size_t topHeight = (bottomHeight + 2*padH - dil_kH) / dH + 1;
    const size_t topWidth  = (bottomWidth + 2*padW - dil_kW) / dW + 1;
    const size_t topDepth  = (bottomDepth + 2*padD - dil_kD) / dD + 1;
    if (batchSize != PyGpuArray_DIMS(top)[0] ||
            nFilters != PyGpuArray_DIMS(top)[1] ||
            topHeight != PyGpuArray_DIMS(top)[2] ||
            topWidth != PyGpuArray_DIMS(top)[3] ||
            topDepth != PyGpuArray_DIMS(top)[4]) {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM shape inconsistency:\n"
                "  bottom shape: %zu %zu %zu %zu %zu\n"
                "  weight shape: %zu %zu %zu %zu %zu\n"
                "  top shape: %zu %zu %zu %zu %zu (expected %zu %zu %zu %zu %zu)\n",
                batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
                nFilters, nChannels, kH, kW, kD,
                PyGpuArray_DIMS(top)[0], PyGpuArray_DIMS(top)[1],
                PyGpuArray_DIMS(top)[2], PyGpuArray_DIMS(top)[3], PyGpuArray_DIMS(top)[4],
                batchSize, nFilters, topHeight, topWidth, topDepth);
        return NULL;
    }

    int err = gpublas_setup(bottom->context->ctx);
    if (err != GA_NO_ERROR) {
        PyErr_SetString(PyExc_RuntimeError, "Can't setup blas");
        return NULL;
    }

    // Get the max threads per blocks
    size_t max_threads_dim;
    err = gpucontext_property(bottom->context->ctx, GA_CTX_PROP_MAXLSIZE, &max_threads_dim);
    if (err != GA_NO_ERROR){
        PyErr_Format(PyExc_RuntimeError,
                     "Could not fetch max_threads_dim.");
        return NULL;
    }

    // Create temporary columns: one row per (channel, kernel element) and one
    // column per output position. The buffer is reused for every batch item.
    size_t col_dim[2];
    col_dim[0] = nChannels * kW * kH * kD;
    col_dim[1] = topHeight * topWidth * topDepth;
    PyGpuArrayObject* col = (PyGpuArrayObject*)pygpu_empty(2, col_dim,
                                                           bottom->ga.typecode,
                                                           GA_C_ORDER,
                                                           bottom->context,
                                                           Py_None);
    if (NULL == col)
    {
        PyErr_Format(PyExc_RuntimeError,
                "GpuCorr3dMM failed to allocate working memory of %zu x %zu\n",
                col_dim[0], col_dim[1]);
        return NULL;
    }

    // Define some useful variables
    // Per-sample offsets are expressed in elements, not bytes.
    const size_t bottom_stride = PyGpuArray_STRIDES(bottom)[0] / gpuarray_get_elsize(bottom->ga.typecode);
    const size_t top_stride = PyGpuArray_STRIDES(top)[0] / gpuarray_get_elsize(top->ga.typecode);
    const size_t K_ = col_dim[0];
    const size_t N_ = col_dim[1];
    const size_t M_ = nFilters;
    const DTYPE_INPUT_0 one = 1.0f;
    const DTYPE_INPUT_0 zero = 0.0f;

    PyGpuArrayObject *output = NULL;
    if (direction == 0) {  // forward pass
        output = top;
        // valid correlation: im3d2col, then gemm
        // Iterate over batch
        for (size_t n = 0; n < batchSize; n++) {
            // First, im3d2col
            // NOTE(review): im3d2col is defined outside this block; it is
            // presumably responsible for setting the Python exception on
            // failure, since none is set here - confirm.
            err = im3d2col(max_threads_dim,
                           bottom->ga.data, n * bottom_stride, nChannels, bottomHeight,
                           bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
                           padH, padW, padD, dH, dW, dD, col->ga.data);
            if (err != GA_NO_ERROR) {
                Py_DECREF(col);
                return NULL;
            }
            // Second, gemm
            err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans,
                                N_, M_, K_, one,
                                col->ga.data, 0, N_,
                                weight->ga.data, 0, K_,
                                zero,
                                top->ga.data, n * top_stride, N_);
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered an error running sgemm.\n");
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    else if (direction == 1) {  // backprop wrt. weights
        output = weight;
        // valid convolution: im3col, then gemm
        // Iterate over batch
        for (size_t n = 0; n < batchSize; n++) {
            // First, im3d2col
            err = im3d2col(max_threads_dim,
                           bottom->ga.data, n * bottom_stride, nChannels, bottomHeight,
                           bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
                           padH, padW, padD, dH, dW, dD, col->ga.data);
            if (err != GA_NO_ERROR) {
                Py_DECREF(col);
                return NULL;
            }
            // Second, gemm
            // Note that we accumulate into weight. We do so by setting beta = 0
            // for the first iteration and beta = 1 for subsequent ones. (This
            // is faster than setting weight to all zeros before the loop.)
            err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans,
                                K_, M_, N_, one,
                                col->ga.data, 0, N_,
                                top->ga.data, n * top_stride, N_,
                                (n == 0) ? zero : one,
                                weight->ga.data, 0, K_);
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered an error running sgemm.\n");
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    else {  // direction == 2: backprop wrt. inputs
        output = bottom;
        // full convolution: gemm, then col2im3d
        // Iterate over batch
        for (size_t n = 0; n < batchSize; n++) {
            // gemm into columns
            err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_trans,
                                N_, K_, M_, one,
                                top->ga.data, n * top_stride, N_,
                                weight->ga.data, 0, K_,
                                zero,
                                col->ga.data, 0, N_);
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                        "GpuCorr3dMM encountered an error running sgemm.\n");
                Py_DECREF(col);
                return NULL;
            }
            // col2im3d back to the data
            err = col2im3d(max_threads_dim,
                           col->ga.data, nChannels,
                           bottomHeight, bottomWidth, bottomDepth,
                           kH, kW, kD, dilH, dilW, dilD, padH, padW, padD,
                           dH, dW, dD, bottom->ga.data, n * bottom_stride);
            if (err != GA_NO_ERROR) {
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    // Free temporary columns
    Py_DECREF(col);

    // Note that we don't change the refcount of the output matrix here. Output
    // (re)allocation and refcounting is done in BaseGpuCorr3dMM.c_code_helper();
    // in here output is just aliased to one of bottom, weights, or top.
    return output;
}
示例#2
0
// Theano op code
// Authors: Arjun Jain, Frederic Bastien, Jan Schluter
// Reference code: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
//   and https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
// Adaptation for 3d
/**
 * Output extent of a strided, dilated "valid" correlation along one axis:
 *     floor((padded - dil_k) / stride) + 1
 * using Python-style flooring division, computed without unsigned wrap-around.
 *
 * padded is the input extent including padding on both sides, dil_k the
 * dilated kernel extent and stride the (non-zero) step.
 *
 * Returns 1 and stores the result in *extent when it is >= 0; returns 0 (and
 * leaves *extent untouched) when the result would be negative.
 */
static int corr3dMM_output_extent(const size_t padded, const size_t dil_k,
                                  const size_t stride, size_t *const extent)
{
    if (padded >= dil_k) {
        *extent = (padded - dil_k) / stride + 1;
        return 1;
    }
    // Negative numerator -d: floor(-d / stride) + 1 == 1 - ceil(d / stride),
    // which is 0 for 1 <= d <= stride and negative for any larger d.
    if (dil_k - padded <= stride) {
        *extent = 0;
        return 1;
    }
    return 0;
}

/**
 * 3D correlation / gradients via "im3d2col + GEMM".
 *
 * All three arrays must be 5D and C-contiguous:
 *   bottom: (batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth)
 *   weight: (nFilters,  nChannels, kH, kW, kD)
 *   top:    (batchSize, nFilters,  topHeight, topWidth, topDepth)
 *
 * direction selects which array is written:
 *   0 - forward pass:            top    is computed from bottom and weight
 *   1 - gradient wrt. weights:   weight is computed from bottom and top
 *   2 - gradient wrt. inputs:    bottom is computed from weight and top
 *
 * dH/dW/dD are the strides, dilH/dilW/dilD the filter dilations and
 * padH/padW/padD the paddings along height/width/depth.
 *
 * The GEMM routine is chosen from the array typecode (float, double or half);
 * any other typecode is reported as a RuntimeError. When there is nothing to
 * accumulate (empty batch, no channels, no filters or an empty output volume)
 * the written array is filled with zeros instead.
 *
 * Returns the array that was written (an alias of top, weight or bottom; its
 * reference count is NOT changed), or NULL on failure.
 */
PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
                           PyGpuArrayObject *const weight,
                           PyGpuArrayObject *const top,
                           const size_t direction,
                           const size_t dH = 1,
                           const size_t dW = 1,
                           const size_t dD = 1,
                           const size_t dilH = 1,
                           const size_t dilW = 1,
                           const size_t dilD = 1,
                           const size_t padH = 0,
                           const size_t padW = 0,
                           const size_t padD = 0)
{
    // Reject unknown directions up front; otherwise no branch below would
    // assign the returned array.
    if (direction > 2)
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM: invalid direction %zu (expected 0, 1 or 2)",
                direction);
        return NULL;
    }
    // The strides are used as divisors when computing the output extents.
    if (dH == 0 || dW == 0 || dD == 0)
    {
        PyErr_SetString(PyExc_ValueError,
                "GpuCorr3dMM requires non-zero strides (dH, dW, dD)");
        return NULL;
    }

    if (PyGpuArray_NDIM(bottom) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires bottom of 5D");
        return NULL;
    }
    if (!GpuArray_IS_C_CONTIGUOUS(&bottom->ga))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires bottom to be C-contiguous, "
                "but strides are: %zd %zd %zd %zd %zd\n",
                PyGpuArray_STRIDES(bottom)[0],
                PyGpuArray_STRIDES(bottom)[1],
                PyGpuArray_STRIDES(bottom)[2],
                PyGpuArray_STRIDES(bottom)[3],
                PyGpuArray_STRIDES(bottom)[4]);
        return NULL;
    }

    if (PyGpuArray_NDIM(weight) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires weight of 5D");
        return NULL;
    }
    if (!GpuArray_IS_C_CONTIGUOUS(&weight->ga))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires weight to be C-contiguous, "
                "but strides are: %zd %zd %zd %zd %zd\n",
                PyGpuArray_STRIDES(weight)[0],
                PyGpuArray_STRIDES(weight)[1],
                PyGpuArray_STRIDES(weight)[2],
                PyGpuArray_STRIDES(weight)[3],
                PyGpuArray_STRIDES(weight)[4]);
        return NULL;
    }

    if (PyGpuArray_NDIM(top) != 5)
    {
        PyErr_SetString(PyExc_ValueError, "GpuCorr3dMM requires top of 5D");
        return NULL;
    }
    if (!GpuArray_IS_C_CONTIGUOUS(&top->ga))
    {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM requires top to be C-contiguous, "
                "but strides are: %zd %zd %zd %zd %zd\n",
                PyGpuArray_STRIDES(top)[0],
                PyGpuArray_STRIDES(top)[1],
                PyGpuArray_STRIDES(top)[2],
                PyGpuArray_STRIDES(top)[3],
                PyGpuArray_STRIDES(top)[4]);
        return NULL;
    }

    // Extract some shape information for later and check shape consistency
    // bottom: (batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth)
    const size_t batchSize = PyGpuArray_DIMS(bottom)[0];
    const size_t nChannels = PyGpuArray_DIMS(bottom)[1];
    const size_t bottomHeight = PyGpuArray_DIMS(bottom)[2];
    const size_t bottomWidth = PyGpuArray_DIMS(bottom)[3];
    const size_t bottomDepth = PyGpuArray_DIMS(bottom)[4];
    // weights: (nFilters, nChannels, rows, columns, slices)
    const size_t nFilters = PyGpuArray_DIMS(weight)[0];
    const size_t kH = PyGpuArray_DIMS(weight)[2];
    const size_t kW = PyGpuArray_DIMS(weight)[3];
    const size_t kD = PyGpuArray_DIMS(weight)[4];
    if (nChannels != PyGpuArray_DIMS(weight)[1]) {
        PyErr_SetString(PyExc_ValueError,
                "GpuCorr3dMM images and kernel must have the same stack size\n");
        return NULL;
    }
    // implicit dilated filter
    const size_t dil_kH = (kH - 1) * dilH + 1;
    const size_t dil_kW = (kW - 1) * dilW + 1;
    const size_t dil_kD = (kD - 1) * dilD + 1;
    // top: (batchSize, nFilters, topHeight, topWidth, topDepth)
    // (padded input - dilated kernel) may be negative, so the extents are
    // computed with Python-like flooring division to stay compatible with
    // get_conv_output. All operands are unsigned, hence the helper instead
    // of a plain subtraction.
    size_t topHeight = 0;
    size_t topWidth = 0;
    size_t topDepth = 0;
    if (!corr3dMM_output_extent(bottomHeight + 2*padH, dil_kH, dH, &topHeight) ||
            !corr3dMM_output_extent(bottomWidth + 2*padW, dil_kW, dW, &topWidth) ||
            !corr3dMM_output_extent(bottomDepth + 2*padD, dil_kD, dD, &topDepth)) {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM shape inconsistency: the dilated kernel exceeds "
                "the padded input by more than one stride\n"
                "  bottom shape: %zu %zu %zu %zu %zu\n"
                "  weight shape: %zu %zu %zu %zu %zu\n"
                "  top shape: %zu %zu %zu %zu %zu\n",
                batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
                nFilters, nChannels, kH, kW, kD,
                PyGpuArray_DIMS(top)[0], PyGpuArray_DIMS(top)[1],
                PyGpuArray_DIMS(top)[2], PyGpuArray_DIMS(top)[3], PyGpuArray_DIMS(top)[4]);
        return NULL;
    }
    if (batchSize != PyGpuArray_DIMS(top)[0] ||
            nFilters != PyGpuArray_DIMS(top)[1] ||
            topHeight != PyGpuArray_DIMS(top)[2] ||
            topWidth != PyGpuArray_DIMS(top)[3] ||
            topDepth != PyGpuArray_DIMS(top)[4]) {
        PyErr_Format(PyExc_ValueError,
                "GpuCorr3dMM shape inconsistency:\n"
                "  bottom shape: %zu %zu %zu %zu %zu\n"
                "  weight shape: %zu %zu %zu %zu %zu\n"
                "  top shape: %zu %zu %zu %zu %zu (expected %zu %zu %zu %zu %zu)\n",
                batchSize, nChannels, bottomHeight, bottomWidth, bottomDepth,
                nFilters, nChannels, kH, kW, kD,
                PyGpuArray_DIMS(top)[0], PyGpuArray_DIMS(top)[1],
                PyGpuArray_DIMS(top)[2], PyGpuArray_DIMS(top)[3], PyGpuArray_DIMS(top)[4],
                batchSize, nFilters, topHeight, topWidth, topDepth);
        return NULL;
    }

    int err = gpublas_setup(bottom->context->ctx);
    if (err != GA_NO_ERROR) {
        PyErr_SetString(PyExc_RuntimeError, "Can't setup blas");
        return NULL;
    }

    // Create temporary columns: one row per (channel, kernel element) and one
    // column per output position. The buffer is reused for every batch item.
    size_t col_dim[2];
    col_dim[0] = nChannels * kW * kH * kD;
    col_dim[1] = topHeight * topWidth * topDepth;
    PyGpuArrayObject* col = (PyGpuArrayObject*)pygpu_empty(2, col_dim,
                                                           bottom->ga.typecode,
                                                           GA_C_ORDER,
                                                           bottom->context,
                                                           Py_None);
    if (NULL == col)
    {
        PyErr_Format(PyExc_RuntimeError,
                "GpuCorr3dMM failed to allocate working memory of %zu x %zu\n",
                col_dim[0], col_dim[1]);
        return NULL;
    }

    // Define some useful variables
    // Per-sample offsets are expressed in elements, not bytes.
    const size_t bottom_stride = PyGpuArray_STRIDES(bottom)[0] / gpuarray_get_elsize(bottom->ga.typecode);
    const size_t top_stride = PyGpuArray_STRIDES(top)[0] / gpuarray_get_elsize(top->ga.typecode);
    const size_t K_ = col_dim[0];
    const size_t N_ = col_dim[1];
    const size_t M_ = nFilters;
    // Nothing to accumulate: the written array is either empty or all zeros.
    // N_ == 0 covers an output volume with a zero-sized spatial extent.
    const int nothing_to_compute =
        (batchSize == 0 || nChannels == 0 || nFilters == 0 || N_ == 0);

    PyGpuArrayObject *output = NULL;
    if (direction == 0) {  // forward pass
        output = top;
        if (nothing_to_compute) {
            err = GpuArray_memset(&output->ga, 0);
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                             "GpuCorr3dMM could not fill the output with zeros: %d", err);
                Py_DECREF(col);
                return NULL;
            }
            Py_DECREF(col);
            return output;
        }
        // valid correlation: im3d2col, then gemm
        // Iterate over batch
        for (size_t n = 0; n < batchSize; n++) {
            // First, im3d2col
            // NOTE(review): im3d2col is defined outside this block; it is
            // presumably responsible for setting the Python exception on
            // failure, since none is set here - confirm.
            err = im3d2col(
              bottom->ga.data, n * bottom_stride, nChannels, bottomHeight,
              bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
              padH, padW, padD, dH, dW, dD, col->ga.data);
            if (err != GA_NO_ERROR) {
                Py_DECREF(col);
                return NULL;
            }
            // Second, gemm (routine picked by element type)
            switch (col->ga.typecode) {
            case GA_FLOAT:
              err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_no_trans,
                                  N_, M_, K_, 1,
                                  col->ga.data, 0, N_,
                                  weight->ga.data, 0, K_,
                                  0,
                                  top->ga.data, n * top_stride, N_);
              break;
            case GA_DOUBLE:
              err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_no_trans,
                                  N_, M_, K_, 1,
                                  col->ga.data, 0, N_,
                                  weight->ga.data, 0, K_,
                                  0,
                                  top->ga.data, n * top_stride, N_);
              break;
            case GA_HALF:
              err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_no_trans,
                                  N_, M_, K_, 1,
                                  col->ga.data, 0, N_,
                                  weight->ga.data, 0, K_,
                                  0,
                                  top->ga.data, n * top_stride, N_);
              break;
            default:
              err = GA_UNSUPPORTED_ERROR;
            }
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                             "GpuCorr3dMM forward encountered an error running gemm.");
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    else if (direction == 1) {  // backprop wrt. weights
        output = weight;
        if (nothing_to_compute) {
            err = GpuArray_memset(&output->ga, 0);
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                             "GpuCorr3dMM grad wrt. weights could not fill the output with zeros: %d", err);
                Py_DECREF(col);
                return NULL;
            }
            Py_DECREF(col);
            return output;
        }
        // valid convolution: im3col, then gemm
        // Iterate over batch
        for (size_t n = 0; n < batchSize; n++) {
            // First, im3d2col
            err = im3d2col(
              bottom->ga.data, n * bottom_stride, nChannels, bottomHeight,
              bottomWidth, bottomDepth, kH, kW, kD, dilH, dilW, dilD,
              padH, padW, padD, dH, dW, dD, col->ga.data);
            if (err != GA_NO_ERROR) {
                Py_DECREF(col);
                return NULL;
            }
            // Second, gemm
            // Note that we accumulate into weight. We do so by setting beta = 0
            // for the first iteration and beta = 1 for subsequent ones. (This
            // is faster than setting weight to all zeros before the loop.)
            switch (col->ga.typecode) {
            case GA_FLOAT:
              err = gpublas_sgemm(cb_fortran, cb_trans, cb_no_trans,
                                  K_, M_, N_, 1,
                                  col->ga.data, 0, N_,
                                  top->ga.data, n * top_stride, N_,
                                  (n == 0) ? 0 : 1,
                                  weight->ga.data, 0, K_);
              break;
            case GA_DOUBLE:
              err = gpublas_dgemm(cb_fortran, cb_trans, cb_no_trans,
                                  K_, M_, N_, 1,
                                  col->ga.data, 0, N_,
                                  top->ga.data, n * top_stride, N_,
                                  (n == 0) ? 0 : 1,
                                  weight->ga.data, 0, K_);
              break;
            case GA_HALF:
              err = gpublas_hgemm(cb_fortran, cb_trans, cb_no_trans,
                                  K_, M_, N_, 1,
                                  col->ga.data, 0, N_,
                                  top->ga.data, n * top_stride, N_,
                                  (n == 0) ? 0 : 1,
                                  weight->ga.data, 0, K_);
              break;
            default:
              err = GA_UNSUPPORTED_ERROR;
            }
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                             "GpuCorr3dMM grad weights encountered an error running gemm.");
                Py_DECREF(col);
                return NULL;
            }
        }
    }
    else {  // direction == 2: backprop wrt. inputs
        output = bottom;
        if (nothing_to_compute) {
            err = GpuArray_memset(&output->ga, 0);
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                             "GpuCorr3dMM grad wrt. inputs could not fill the output with zeros: %d", err);
                Py_DECREF(col);
                return NULL;
            }
            Py_DECREF(col);
            return output;
        }
        // full convolution: gemm, then col2im3d
        // Iterate over batch
        for (size_t n = 0; n < batchSize; n++) {
          // gemm into columns
          switch (top->ga.typecode) {
          case GA_FLOAT:
            err = gpublas_sgemm(cb_fortran, cb_no_trans, cb_trans,
                                N_, K_, M_, 1,
                                top->ga.data, n * top_stride, N_,
                                weight->ga.data, 0, K_,
                                0,
                                col->ga.data, 0, N_);
            break;
          case GA_DOUBLE:
            err = gpublas_dgemm(cb_fortran, cb_no_trans, cb_trans,
                                N_, K_, M_, 1,
                                top->ga.data, n * top_stride, N_,
                                weight->ga.data, 0, K_,
                                0,
                                col->ga.data, 0, N_);
            break;
          case GA_HALF:
            err = gpublas_hgemm(cb_fortran, cb_no_trans, cb_trans,
                                N_, K_, M_, 1,
                                top->ga.data, n * top_stride, N_,
                                weight->ga.data, 0, K_,
                                0,
                                col->ga.data, 0, N_);
            break;
          default:
            err = GA_UNSUPPORTED_ERROR;
          }
          if (err != GA_NO_ERROR) {
            PyErr_Format(PyExc_RuntimeError,
                         "GpuCorr3dMM grad inputs encountered an error running gemm.");
            Py_DECREF(col);
            return NULL;
          }
          // col2im3d back to the data
          err = col2im3d(col->ga.data, nChannels,
                         bottomHeight, bottomWidth, bottomDepth,
                         kH, kW, kD, dilH, dilW, dilD, padH, padW, padD,
                         dH, dW, dD, bottom->ga.data, n * bottom_stride);
          if (err != GA_NO_ERROR) {
            Py_DECREF(col);
            return NULL;
          }
        }
    }
    // Free temporary columns
    Py_DECREF(col);

    // Note that we don't change the refcount of the output matrix here. Output
    // (re)allocation and refcounting is done in BaseGpuCorr3dMM.c_code_helper();
    // in here output is just aliased to one of bottom, weights, or top.
    return output;
}