From 0368d686d3b22228a9bd854e61a226dc8ab1016a Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Tue, 15 Dec 2020 21:56:19 +0800 Subject: [PATCH] conv2d_unify_ir pass adapt for pynative --- .../ascend/mindir/conv2d_unify_mindir.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc index 3ad5166bf3..335783c6e2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc @@ -37,6 +37,7 @@ constexpr auto kAttrPadList = "pad_list"; constexpr auto kAttrPads = "pads"; constexpr auto kAttrMode = "mode"; constexpr auto kAttrChannelMultiplier = "channel_multiplier"; +constexpr auto kAttrPerm = "perm"; bool NeedUpdate(const CNodePtr &conv2d, std::vector in_shape, std::vector out_shape) { MS_EXCEPTION_IF_NULL(conv2d); @@ -86,9 +87,16 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(conv2d); MS_EXCEPTION_IF_NULL(input_node); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); auto perm = std::vector{1, 0, 2, 3}; - std::vector transpose_inputs = {NewValueNode(std::make_shared(kTransposeOpName)), input_node, - CreatePermValueNode(graph, perm)}; + std::vector transpose_inputs; + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { + transpose_inputs = {NewValueNode(std::make_shared(kTransposeOpName)), input_node}; + } else { + transpose_inputs = {NewValueNode(std::make_shared(kTransposeOpName)), input_node, + CreatePermValueNode(graph, perm)}; + } auto transpose = graph->NewCNode(transpose_inputs); MS_EXCEPTION_IF_NULL(transpose); transpose->set_scope(conv2d->scope()); @@ -111,6 +119,9 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons auto output_names = std::vector{"output"}; AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose); AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose); + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { + AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), transpose); + } return transpose; }