/// Backward pass entry point for the MKL-DNN convolution.
///
/// `grad_output_t` is made contiguous first. Depending on `output_mask`:
///   - output_mask[0]: the input gradient is computed via
///     at::mkldnn_convolution_backward_input;
///   - output_mask[1] or output_mask[2]: the weight and bias gradients are
///     computed together via at::mkldnn_convolution_backward_weights.
/// In both calls output_mask[2] is forwarded as the trailing bool argument.
///
/// @return (grad_input, grad_weight, grad_bias). A slot whose branch was
///         skipped is returned as a default-constructed Tensor.
std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntList padding, IntList stride, IntList dilation, std::array<bool,3> output_mask)
{
  const bool want_grad_input = output_mask[0];
  const bool want_grad_weight = output_mask[1];
  const bool want_grad_bias = output_mask[2];

  Tensor grad_output = grad_output_t.contiguous();

  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (want_grad_input) {
    grad_input = at::mkldnn_convolution_backward_input(
        input.sizes(), grad_output, weight,
        padding, stride, dilation, want_grad_bias);
  }

  // Weight and bias gradients come from a single call, so either flag
  // triggers it.
  if (want_grad_weight || want_grad_bias) {
    auto weight_and_bias = at::mkldnn_convolution_backward_weights(
        weight.sizes(), grad_output, input,
        padding, stride, dilation, want_grad_bias);
    grad_weight = std::get<0>(weight_and_bias);
    grad_bias = std::get<1>(weight_and_bias);
  }

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
/// Forward convolution executed through MKL-DNN.
///
/// Allocates the output tensor (sized by conv_output_size), describes the
/// operands to MKL-DNN as f32 buffers in nchw / oihw / x layout, lets the
/// library pick its preferred internal layouts (format `any`), inserts
/// reorder primitives wherever the preferred layout differs from the user
/// layout, and submits the resulting primitive list to the shared stream.
///
/// @param input    read as a 4-d tensor (sizes 0..3 are used).
/// @param weight   only sizes 2 and 3 (kernel height/width) are read; the
///                 channel dimensions are taken from `output` and `input`.
/// @param bias     optional; used only when `bias.defined()`.
/// @param padding  two entries (height, width), applied on both sides.
/// @param stride   two entries (height, width).
/// @param dilation used only for the output size computation.
/// @return the freshly allocated output tensor.
///
/// NOTE(review): the raw data pointers are handed to MKL-DNN as dense f32
/// buffers; contiguity and dtype are not checked here -- presumably the
/// caller guarantees them, confirm.
/// NOTE(review): `dilation` is not forwarded to the convolution descriptor;
/// presumably callers only dispatch here when dilation is 1 -- confirm.
at::Tensor mkldnn_convolution(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias,
    IntList padding, IntList stride, IntList dilation)
{
  auto output = input.type().tensor(conv_output_size(
      input.sizes(), weight.sizes(), padding, stride, dilation));

  auto cpu_engine = CpuEngine::Instance().get_engine();

  // Problem geometry, narrowed to the 32-bit ints used by memory::dims.
  const int32_t batch = input.size(0);
  const int32_t in_channels = input.size(1);
  const int32_t in_h = input.size(2);
  const int32_t in_w = input.size(3);

  const int32_t out_channels = output.size(1);
  const int32_t out_h = output.size(2);
  const int32_t out_w = output.size(3);

  const int32_t kernel_h = weight.size(2);
  const int32_t kernel_w = weight.size(3);

  const int32_t stride_h = stride[0];
  const int32_t stride_w = stride[1];
  const int32_t pad_h = padding[0];
  const int32_t pad_w = padding[1];

  const auto dtype = memory::data_type::f32;

  const memory::dims src_dims = {batch, in_channels, in_h, in_w};
  const memory::dims wgt_dims = {out_channels, in_channels, kernel_h, kernel_w};
  const memory::dims bias_dims = {out_channels};
  const memory::dims dst_dims = {batch, out_channels, out_h, out_w};
  const memory::dims conv_strides = {stride_h, stride_w};
  const memory::dims conv_padding = {pad_h, pad_w};

  // Layout-agnostic descriptors: MKL-DNN chooses the layout it prefers.
  const auto src_any_md = memory::desc(src_dims, dtype, memory::format::any);
  const auto wgt_any_md = memory::desc(wgt_dims, dtype, memory::format::any);
  const auto bias_any_md = memory::desc(bias_dims, dtype, memory::format::any);
  const auto dst_any_md = memory::desc(dst_dims, dtype, memory::format::any);

  // The same padding is used for the leading and trailing sides.
  const convolution_forward::desc conv_desc = bias.defined()
      ? convolution_forward::desc(
            prop_kind::forward, convolution_direct,
            src_any_md, wgt_any_md, bias_any_md, dst_any_md,
            conv_strides, conv_padding, conv_padding, padding_kind::zero)
      : convolution_forward::desc(
            prop_kind::forward, convolution_direct,
            src_any_md, wgt_any_md, dst_any_md,
            conv_strides, conv_padding, conv_padding, padding_kind::zero);

  const convolution_forward::primitive_desc conv_pd(conv_desc, cpu_engine);

  // Memory primitives wrapping the tensors' own buffers in their user layout.
  auto src_user_mem = memory(
      memory::primitive_desc(
          memory::desc(src_dims, dtype, memory::format::nchw), cpu_engine),
      input.data_ptr());
  auto wgt_user_mem = memory(
      memory::primitive_desc(
          memory::desc(wgt_dims, dtype, memory::format::oihw), cpu_engine),
      weight.data_ptr());
  auto dst_user_mem = memory(
      memory::primitive_desc(
          memory::desc(dst_dims, dtype, memory::format::nchw), cpu_engine),
      output.data_ptr());

  // Primitives are executed in the order they are appended.
  std::vector<primitive> pipeline;

  // Source: reorder into the library's preferred layout when it differs.
  auto src_expected_pd = conv_pd.src_primitive_desc();
  auto src_mem = src_user_mem;
  if (src_user_mem.get_primitive_desc() != memory::primitive_desc(src_expected_pd)) {
    src_mem = memory(src_expected_pd);
    pipeline.push_back(reorder(src_user_mem, src_mem));
  }

  // Weights: same treatment as the source.
  auto wgt_expected_pd = conv_pd.weights_primitive_desc();
  auto wgt_mem = wgt_user_mem;
  if (wgt_user_mem.get_primitive_desc() != memory::primitive_desc(wgt_expected_pd)) {
    wgt_mem = memory(wgt_expected_pd);
    pipeline.push_back(reorder(wgt_user_mem, wgt_mem));
  }

  // Destination: if the preferred layout differs, compute into a scratch
  // memory and reorder back into the user buffer after the convolution.
  auto dst_expected_pd = conv_pd.dst_primitive_desc();
  auto dst_mem = dst_user_mem;
  if (dst_user_mem.get_primitive_desc() != memory::primitive_desc(dst_expected_pd)) {
    dst_mem = memory(dst_expected_pd);
  }

  // Declared at function scope so the bias memory stays alive until the
  // pipeline has been submitted.
  std::shared_ptr<memory> bias_user_mem;
  if (bias.defined()) {
    bias_user_mem = std::make_shared<memory>(
        memory::primitive_desc(
            memory::desc(bias_dims, dtype, memory::format::x), cpu_engine),
        bias.data_ptr());
    pipeline.push_back(convolution_forward(
        conv_pd, src_mem, wgt_mem, *bias_user_mem, dst_mem));
  } else {
    pipeline.push_back(convolution_forward(
        conv_pd, src_mem, wgt_mem, dst_mem));
  }

  if (dst_mem != dst_user_mem) {
    pipeline.push_back(reorder(dst_mem, dst_user_mem));
  }

  Stream::Instance().get_stream().submit(pipeline);

  return output;
}