From e34b2873fae73acf67d6c61fb5d0f8bc7768d247 Mon Sep 17 00:00:00 2001 From: yangwei Date: Thu, 18 Mar 2021 19:14:28 +0800 Subject: [PATCH] set abstract for maketuple --- mindspore/ccsrc/backend/session/session_basic.cc | 1 + tests/st/control/test_switch_layer.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 79e3347662..7cb64d7570 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1049,6 +1049,7 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr new_make_tuple_inputs.emplace_back(new_partial); } auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs); + new_make_tuple->set_abstract(make_tuple_node->abstract()); switch_layer_inputs.emplace_back(new_make_tuple); auto new_switch_layer = graph->NewCNode(switch_layer_inputs); cnode_inputs.emplace_back(new_switch_layer); diff --git a/tests/st/control/test_switch_layer.py b/tests/st/control/test_switch_layer.py index e62c0584d4..3f33b94b82 100644 --- a/tests/st/control/test_switch_layer.py +++ b/tests/st/control/test_switch_layer.py @@ -38,6 +38,8 @@ class CaseNet(nn.Cell): @pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_switch_layer():