From 0b678d401bc43c336bdd41143151222fe99c13a1 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Wed, 28 Oct 2020 04:04:39 +0100 Subject: [PATCH] - sum (#28233) test=develop --- paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc index 414312fe97..bdff665f0f 100644 --- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc @@ -80,8 +80,6 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { auto& input0 = in_vars[0]->Get(); in_place = (input0.numel() > 0) && (input0.data() == output_data); - MKLDNNMemoryFormat input_format = input0.format(); - for (size_t i = 0; i < in_vars.size(); i++) { auto& input_it = in_vars[i]->Get(); if (input_it.numel() == 0) { @@ -89,6 +87,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { } const T* input_data = input_it.data(); + MKLDNNMemoryFormat input_format = input_it.format(); auto src_md = memory::desc(src_tz, memory::data_type::f32, input_format); auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data)); @@ -115,7 +114,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { std::shared_ptr reorder_p; std::shared_ptr target_mem; if (in_place) { - output_format = input_format; + output_format = input0.format(); target_mem.reset( new memory({{src_tz}, memory::data_type::f32, output_format}, mkldnn_engine, output_data));