diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 79e3347662..7cb64d7570 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1049,6 +1049,7 @@ std::vector SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr new_make_tuple_inputs.emplace_back(new_partial); } auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs); + new_make_tuple->set_abstract(make_tuple_node->abstract()); switch_layer_inputs.emplace_back(new_make_tuple); auto new_switch_layer = graph->NewCNode(switch_layer_inputs); cnode_inputs.emplace_back(new_switch_layer); diff --git a/tests/st/control/test_switch_layer.py b/tests/st/control/test_switch_layer.py index e62c0584d4..3f33b94b82 100644 --- a/tests/st/control/test_switch_layer.py +++ b/tests/st/control/test_switch_layer.py @@ -38,6 +38,8 @@ class CaseNet(nn.Cell): @pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_switch_layer():