/*
 * Backward pass of a convolutional layer.
 *
 * Steps, in order:
 *   1. gradient_array() is applied to l.delta using l.output and l.activation.
 *   2. backward_bias() accumulates into l.bias_updates.
 *   3. For every batch item, the input is unrolled with im2col_cpu() into
 *      l.col_image and gemm() accumulates into l.filter_updates.
 *   4. If state.delta is non-NULL, a second gemm() fills l.col_image from
 *      l.filters and the item's delta, and col2im_cpu() folds that buffer
 *      into the matching slice of state.delta.
 *
 * l.col_image is scratch space: it is overwritten twice per batch item.
 */
void backward_convolutional_layer(convolutional_layer l, network_state state)
{
    const int num_filters = l.n;
    const int patch_len = l.size*l.size*l.c;
    const int out_h = convolutional_out_height(l);
    const int out_w = convolutional_out_width(l);
    const int out_area = out_h*out_w;
    const int delta_step = num_filters*out_area;   // floats of l.delta per batch item
    const int input_step = l.c*l.h*l.w;            // floats of input per batch item
    int item;

    gradient_array(l.output, delta_step*l.batch, l.activation, l.delta);
    backward_bias(l.bias_updates, l.delta, l.batch, l.n, out_area);

    for (item = 0; item < l.batch; ++item) {
        float *item_delta = l.delta + item*delta_step;
        float *item_input = state.input + item*input_step;

        // Unroll this item's input, then accumulate the filter updates
        // (the 1 before l.filter_updates is presumably beta, i.e. "add to").
        im2col_cpu(item_input, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image);
        gemm(0, 1, num_filters, patch_len, out_area, 1,
             item_delta, out_area,
             l.col_image, out_area,
             1, l.filter_updates, patch_len);

        // No upstream delta buffer: nothing to propagate for this item.
        if (state.delta) {
            gemm(1, 0, patch_len, out_area, num_filters, 1,
                 l.filters, patch_len,
                 item_delta, out_area,
                 0, l.col_image, out_area);
            col2im_cpu(l.col_image, l.c, l.h, l.w, l.size, l.stride, l.pad,
                       state.delta + item*input_step);
        }
    }
}
/*
 * Backward pass of batch normalization.
 *
 * Calls, in order: backward_bias() into l.bias_updates, backward_scale_cpu()
 * into l.scale_updates, scale_bias() on l.delta with l.scales, then
 * mean_delta_cpu(), variance_delta_cpu() and normalize_delta_cpu(), the last
 * of which writes back into l.delta.
 *
 * When the layer itself is of type BATCHNORM, the resulting l.delta is copied
 * to state.delta. Other layer types that call this helper get no copy.
 *
 * state.delta may be NULL (other backward passes test for it before use), so
 * the copy is skipped in that case instead of writing through a NULL pointer.
 */
void backward_batchnorm_layer(const layer l, network_state state)
{
    const int spatial = l.out_w*l.out_h;   // elements per channel

    backward_bias(l.bias_updates, l.delta, l.batch, l.out_c, spatial);
    backward_scale_cpu(l.x_norm, l.delta, l.batch, l.out_c, spatial, l.scale_updates);

    scale_bias(l.delta, l.scales, l.batch, l.out_c, spatial);

    mean_delta_cpu(l.delta, l.variance, l.batch, l.out_c, spatial, l.mean_delta);
    variance_delta_cpu(l.x, l.delta, l.mean, l.variance, l.batch, l.out_c, spatial, l.variance_delta);
    normalize_delta_cpu(l.x, l.mean, l.variance, l.mean_delta, l.variance_delta,
                        l.batch, l.out_c, spatial, l.delta);

    if (l.type == BATCHNORM && state.delta) {
        copy_cpu(l.outputs*l.batch, l.delta, 1, state.delta, 1);
    }
}
/*
 * Backward pass of a deconvolutional layer (network_state variant).
 *
 * Steps, in order:
 *   1. gradient_array() is applied to l.delta using l.output and l.activation.
 *   2. Either backward_batchnorm_layer() (when l.batch_normalize is set) or
 *      backward_bias() into l.bias_updates.
 *   3. For every batch item, the item's slice of l.delta is unrolled with
 *      im2col_cpu() (padding argument 0) into state.workspace, and gemm()
 *      accumulates into l.weight_updates with scale alpha = 1/l.batch.
 *   4. If state.delta is non-NULL, gemm() combines l.weights with the same
 *      workspace into the item's slice of state.delta. The 1 passed before
 *      the output pointer is presumably beta, so the result is added to the
 *      existing contents; confirm the caller clears state.delta.
 */
void backward_deconvolutional_layer(layer l, network_state state)
{
    const float alpha = 1./l.batch;
    const int out_h = deconvolutional_out_height(l);
    const int out_w = deconvolutional_out_width(l);
    const int size = out_h*out_w;

    const int in_ch = l.c;
    const int in_area = l.h*l.w;
    const int col_rows = l.size*l.size*l.n;
    int i;

    gradient_array(l.output, size*l.n*l.batch, l.activation, l.delta);
    if (l.batch_normalize) {
        backward_batchnorm_layer(l, state);
    } else {
        backward_bias(l.bias_updates, l.delta, l.batch, l.n, l.out_w*l.out_h);
    }

    for (i = 0; i < l.batch; ++i) {
        // The input is handed to gemm as an in_ch x in_area matrix (lda =
        // in_area), so one batch item spans in_ch*in_area floats. The former
        // stride of in_ch*col_rows pointed items i > 0 at the wrong region.
        float *item_input = state.input + i*in_ch*in_area;

        im2col_cpu(l.delta + i*l.n*size, l.n, out_h, out_w,
                   l.size, l.stride, 0, state.workspace);
        gemm(0, 1, in_ch, col_rows, in_area, alpha,
             item_input, in_area,
             state.workspace, in_area,
             1, l.weight_updates, col_rows);

        if (state.delta) {
            float *item_delta = state.delta + i*in_area*in_ch;

            gemm(0, 0, in_ch, in_area, col_rows, 1,
                 l.weights, col_rows,
                 state.workspace, in_area,
                 1, item_delta, in_area);
        }
    }
}
/*
 * Backward pass of a deconvolutional layer (network variant).
 *
 * Steps, in order:
 *   1. gradient_array() is applied to l.delta using l.output and l.activation.
 *   2. Either backward_batchnorm_layer() (when l.batch_normalize is set) or
 *      backward_bias() into l.bias_updates.
 *   3. For every batch item, the item's slice of l.delta is unrolled with
 *      im2col_cpu() into net.workspace, and gemm_cpu() accumulates into
 *      l.weight_updates.
 *   4. If net.delta is non-NULL, gemm_cpu() combines l.weights with the same
 *      workspace into the item's slice of net.delta.
 */
void backward_deconvolutional_layer(layer l, network net)
{
    const int in_ch = l.c;
    const int in_area = l.h*l.w;
    const int col_rows = l.size*l.size*l.n;
    int i;

    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);

    if (l.batch_normalize) {
        backward_batchnorm_layer(l, net);
    } else {
        backward_bias(l.bias_updates, l.delta, l.batch, l.n, l.out_w*l.out_h);
    }

    // NOTE: net.delta is not zeroed here (a memset of
    // l.batch*l.h*l.w*l.c floats used to sit at this point but was disabled).
    // The second gemm_cpu below passes 1 in the slot before its output
    // pointer, presumably beta, so it adds to whatever net.delta already
    // holds; confirm the caller clears it.

    for (i = 0; i < l.batch; ++i) {
        float *item_input = net.input + i*in_ch*in_area;

        im2col_cpu(l.delta + i*l.outputs, l.out_c, l.out_h, l.out_w,
                   l.size, l.stride, l.pad, net.workspace);
        gemm_cpu(0, 1, in_ch, col_rows, in_area, 1,
                 item_input, in_area,
                 net.workspace, in_area,
                 1, l.weight_updates, col_rows);

        if (net.delta) {
            float *item_delta = net.delta + i*in_area*in_ch;

            gemm_cpu(0, 0, in_ch, in_area, col_rows, 1,
                     l.weights, col_rows,
                     net.workspace, in_area,
                     1, item_delta, in_area);
        }
    }
}
/*
 * Backward pass of a grouped convolutional layer (network variant).
 *
 * Steps, in order:
 *   1. gradient_array() is applied to l.delta using l.output and l.activation.
 *   2. Either backward_batchnorm_layer() (when l.batch_normalize is set) or
 *      backward_bias() into l.bias_updates.
 *   3. For every (batch item, group) pair, that pair's input slice is
 *      unrolled with im2col_cpu() into net.workspace and gemm() accumulates
 *      into the group's slice of l.weight_updates.
 *   4. If net.delta is non-NULL, a second gemm() fills net.workspace from the
 *      group's weights and the pair's delta, and col2im_cpu() folds it into
 *      the matching slice of net.delta.
 *
 * net.workspace is scratch space: it is overwritten twice per pair.
 */
void backward_convolutional_layer(convolutional_layer l, network net)
{
    const int group_filters = l.n/l.groups;
    const int patch_len = l.size*l.size*l.c/l.groups;
    const int out_area = l.out_w*l.out_h;
    int item, grp;

    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);

    if (l.batch_normalize) {
        backward_batchnorm_layer(l, net);
    } else {
        backward_bias(l.bias_updates, l.delta, l.batch, l.n, out_area);
    }

    for (item = 0; item < l.batch; ++item) {
        for (grp = 0; grp < l.groups; ++grp) {
            // Flat index of this (item, group) pair; the offset expressions
            // below keep the original multiply-then-divide order.
            const int slot = item*l.groups + grp;

            float *slot_delta = l.delta + slot*group_filters*out_area;
            float *slot_input = net.input + slot*l.c/l.groups*l.h*l.w;
            float *group_updates = l.weight_updates + grp*l.nweights/l.groups;

            im2col_cpu(slot_input, l.c/l.groups, l.h, l.w,
                       l.size, l.stride, l.pad, net.workspace);
            gemm(0, 1, group_filters, patch_len, out_area, 1,
                 slot_delta, out_area,
                 net.workspace, out_area,
                 1, group_updates, patch_len);

            if (net.delta) {
                float *group_weights = l.weights + grp*l.nweights/l.groups;

                gemm(1, 0, patch_len, out_area, group_filters, 1,
                     group_weights, patch_len,
                     slot_delta, out_area,
                     0, net.workspace, out_area);
                col2im_cpu(net.workspace, l.c/l.groups, l.h, l.w,
                           l.size, l.stride, l.pad,
                           net.delta + slot*l.c/l.groups*l.h*l.w);
            }
        }
    }
}