/*
 * Backward pass of a local layer, issued on the stream/handle carried by
 * `state.st_handle`.
 *
 * Steps, in order:
 *   1. gradient_array_ongpu() is applied to l.delta_gpu using l.output_gpu
 *      and l.activation over l.outputs*l.batch elements.
 *   2. For every batch item, its l.outputs-long slice of l.delta_gpu is
 *      added (axpy, alpha = 1) into l.bias_updates_gpu.
 *   3. For every batch item, the input image is expanded with
 *      im2col_ongpu() into l.col_image_gpu, then one gemm per output
 *      location accumulates (beta = 1) into that location's slice of
 *      l.weight_updates_gpu.
 *   4. Only when state.delta is non-NULL: one gemm per output location
 *      overwrites (beta = 0) l.col_image_gpu from l.weights_gpu and
 *      l.delta_gpu, and col2im_ongpu() folds the result into the batch
 *      item's slice of state.delta.
 *
 * NOTE(review): the first two gemm_ongpu arguments are presumably the
 * transpose-A / transpose-B flags -- confirm against gemm_ongpu's definition.
 */
void backward_local_layer_gpu(local_layer l, network_state state)
{
    int b, loc;
    int num_locations = l.out_w*l.out_h;
    int patch_len = l.size*l.size*l.c;   /* elements in one im2col column */
    int image_len = l.w*l.h*l.c;         /* elements in one input image   */

    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation,
                         l.delta_gpu, state.st_handle.stream);

    /* Accumulate each batch item's delta into the bias updates. */
    for (b = 0; b < l.batch; ++b) {
        axpy_ongpu(l.outputs, 1, l.delta_gpu + b*l.outputs, 1,
                   l.bias_updates_gpu, 1, state.st_handle.stream);
    }

    for (b = 0; b < l.batch; ++b) {
        float *batch_delta = l.delta_gpu + b*l.outputs;

        im2col_ongpu(state.input + b*image_len, l.c, l.h, l.w,
                     l.size, l.stride, l.pad, l.col_image_gpu,
                     state.st_handle.stream);

        /* Weight updates: each location owns its own l.n x patch_len block. */
        for (loc = 0; loc < num_locations; ++loc) {
            float *delta_at  = batch_delta + loc;
            float *col_at    = l.col_image_gpu + loc;
            float *update_at = l.weight_updates_gpu + loc*patch_len*l.n;

            gemm_ongpu(0, 1, l.n, patch_len, 1, 1,
                       delta_at, num_locations,
                       col_at, num_locations,
                       1, update_at, patch_len, state.st_handle);
        }

        /* Nothing to propagate when the caller supplied no delta buffer. */
        if (!state.delta) {
            continue;
        }

        /* Rebuild the column buffer from the weights and delta (beta = 0). */
        for (loc = 0; loc < num_locations; ++loc) {
            float *weights_at = l.weights_gpu + loc*patch_len*l.n;
            float *delta_at   = batch_delta + loc;
            float *col_at     = l.col_image_gpu + loc;

            gemm_ongpu(1, 0, patch_len, 1, l.n, 1,
                       weights_at, patch_len,
                       delta_at, num_locations,
                       0, col_at, num_locations, state.st_handle);
        }

        col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad,
                     state.delta + b*image_len, state.st_handle.stream);
    }
}
/*
 * Backward pass of a local layer on the GPU.
 *
 * Steps, in order:
 *   1. gradient_array_ongpu() is applied to l.delta_gpu using l.output_gpu
 *      and l.activation over l.outputs*l.batch elements.
 *   2. For every batch item, its l.outputs-long slice of l.delta_gpu is
 *      added (axpy, alpha = 1) into l.bias_updates_gpu.
 *   3. For every batch item, the input image is expanded with
 *      im2col_ongpu() into l.col_image_gpu, then one gemm per output
 *      location accumulates (beta = 1) into that location's slice of
 *      l.filter_updates_gpu.
 *   4. Only when state.delta is non-NULL: one gemm per output location
 *      overwrites (beta = 0) l.col_image_gpu from l.filters_gpu and
 *      l.delta_gpu, and col2im_ongpu() folds the result into the batch
 *      item's slice of state.delta.
 *
 * NOTE(review): the first two gemm_ongpu arguments are presumably the
 * transpose-A / transpose-B flags -- confirm against gemm_ongpu's definition.
 */
void backward_local_layer_gpu(local_layer l, network_state state)
{
    int b, loc;
    int num_locations = l.out_w*l.out_h;
    int patch_len = l.size*l.size*l.c;   /* elements in one im2col column */
    int image_len = l.w*l.h*l.c;         /* elements in one input image   */

    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation,
                         l.delta_gpu);

    /* Accumulate each batch item's delta into the bias updates. */
    for (b = 0; b < l.batch; ++b) {
        axpy_ongpu(l.outputs, 1, l.delta_gpu + b*l.outputs, 1,
                   l.bias_updates_gpu, 1);
    }

    for (b = 0; b < l.batch; ++b) {
        float *batch_delta = l.delta_gpu + b*l.outputs;

        im2col_ongpu(state.input + b*image_len, l.c, l.h, l.w,
                     l.size, l.stride, l.pad, l.col_image_gpu);

        /* Filter updates: each location owns its own l.n x patch_len block. */
        for (loc = 0; loc < num_locations; ++loc) {
            float *delta_at  = batch_delta + loc;
            float *col_at    = l.col_image_gpu + loc;
            float *update_at = l.filter_updates_gpu + loc*patch_len*l.n;

            gemm_ongpu(0, 1, l.n, patch_len, 1, 1,
                       delta_at, num_locations,
                       col_at, num_locations,
                       1, update_at, patch_len);
        }

        /* Nothing to propagate when the caller supplied no delta buffer. */
        if (!state.delta) {
            continue;
        }

        /* Rebuild the column buffer from the filters and delta (beta = 0). */
        for (loc = 0; loc < num_locations; ++loc) {
            float *filters_at = l.filters_gpu + loc*patch_len*l.n;
            float *delta_at   = batch_delta + loc;
            float *col_at     = l.col_image_gpu + loc;

            gemm_ongpu(1, 0, patch_len, 1, l.n, 1,
                       filters_at, patch_len,
                       delta_at, num_locations,
                       0, col_at, num_locations);
        }

        col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad,
                     state.delta + b*image_len);
    }
}