diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc index 3bead16ce4..0a18882e81 100644 --- a/paddle/fluid/operators/lrn_mkldnn_op.cc +++ b/paddle/fluid/operators/lrn_mkldnn_op.cc @@ -36,6 +36,14 @@ std::shared_ptr insert_to_context(const std::string& key, return p; } + +template +void run_primitive(Args&&... args) { + auto forward_op = mkldnn::lrn_forward{args...}; + + std::vector pipeline = {forward_op}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); +} } // namespace template @@ -87,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine}, static_cast(output_data)}; - std::unique_ptr forward_op = nullptr; - if (!is_test) { const std::string key = ctx.op().Output("Out"); const std::string key_src_memory = key + "@lrn_src_memory"; @@ -108,9 +114,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { key_workspace_memory, dev_ctx, forward_pd->workspace_primitive_desc()); - forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory, - *workspace_memory, dst_memory}); - + run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory); } else { auto forward_pd = mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine}; @@ -119,12 +123,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto workspace_memory = mkldnn::memory{forward_pd.workspace_primitive_desc()}; - forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory, - workspace_memory, dst_memory}); + run_primitive(forward_pd, src_memory, workspace_memory, dst_memory); } - - std::vector pipeline = {*forward_op}; - mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); } }; @@ -136,6 +136,9 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel { "MKLDNN LRN must use float data."); PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "MKLDNN LRN must use CPUPlace."); + PADDLE_ENFORCE( + !ctx.Attr("is_test"), + "is_test attribute should be set to False in training phase."); auto x = ctx.Input("X"); diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 2b1947a187..b36b5c3a33 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -155,8 +155,8 @@ class LRNOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4."); ctx->SetOutputDim("Out", x_dim); - ctx->SetOutputDim("MidOut", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("MidOut", x_dim); } framework::OpKernelType GetExpectedKernelType( diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py index 2268eafdbd..8fa480b9bc 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py @@ -97,5 +97,24 @@ class TestLRNMKLDNNOp(TestLRNOp): self.check_output(atol=0.002) +class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): + def get_attrs(self): + attrs = TestLRNMKLDNNOp.get_attrs(self) + attrs['is_test'] = True + return attrs + + def test_check_grad_normal(self): + def check_raise_is_test(): + try: + self.check_grad(['X'], 'Out', max_relative_error=0.01) + except Exception as e: + t = \ + "is_test attribute should be set to False in training phase." + if t in str(e): + raise AttributeError + + self.assertRaises(AttributeError, check_raise_is_test) + + if __name__ == "__main__": unittest.main()