|
|
|
@ -72,10 +72,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(input_data));
|
|
|
|
|
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(filter_data));
|
|
|
|
|
auto src_memory =
|
|
|
|
|
mkldnn::memory({src_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
auto weights_memory =
|
|
|
|
|
mkldnn::memory({weights_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
|
|
|
|
|
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
|
|
|
|
@ -180,9 +182,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
// create memory
|
|
|
|
|
auto diff_dst_memory =
|
|
|
|
|
mkldnn::memory({diff_weights_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(output_grad_data));
|
|
|
|
|
auto diff_dst_memory = mkldnn::memory(
|
|
|
|
|
{diff_weights_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(output_grad_data)));
|
|
|
|
|
// Retrieve conv_pd from device context
|
|
|
|
|
auto conv_pd =
|
|
|
|
|
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
|
|
|
|
@ -202,8 +204,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto diff_weights_memory =
|
|
|
|
|
mkldnn::memory({diff_weights_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(filter_grad_data));
|
|
|
|
|
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(input_data));
|
|
|
|
|
auto src_memory =
|
|
|
|
|
mkldnn::memory({src_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(input_data)));
|
|
|
|
|
|
|
|
|
|
// create backward conv primitive for weights
|
|
|
|
|
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights(
|
|
|
|
@ -222,11 +225,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
strides, paddings, *conv_pd, mkldnn_engine);
|
|
|
|
|
|
|
|
|
|
// create memory
|
|
|
|
|
auto diff_src_memory =
|
|
|
|
|
mkldnn::memory({diff_src_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(input_grad_data));
|
|
|
|
|
auto weights_memory = mkldnn::memory(
|
|
|
|
|
{weights_md, mkldnn_engine}, reinterpret_cast<void*>(filter_data));
|
|
|
|
|
auto diff_src_memory = mkldnn::memory(
|
|
|
|
|
{diff_src_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(input_grad_data)));
|
|
|
|
|
auto weights_memory =
|
|
|
|
|
mkldnn::memory({weights_md, mkldnn_engine},
|
|
|
|
|
reinterpret_cast<void*>(const_cast<T*>(filter_data)));
|
|
|
|
|
|
|
|
|
|
// create backward conv primitive for data
|
|
|
|
|
auto conv_bwd_data_prim = mkldnn::convolution_backward_data(
|
|
|
|
|