diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 61d96944f7..fb1acfd713 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -438,8 +438,9 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { void CompileGraph::Push(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (slots_.count(node) > 0) { - MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString() - << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + MS_LOG(WARNING) << "Push failed node in slots:" << node->DebugString() + << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + return; } MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_ << " is parameter: " << node->isa(); diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index 091e7af7bc..1356dc2fc3 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -341,13 +341,15 @@ void FinalVM::InstSwitchLayer(const VectorRef &args) { if (!backend_->GetIndex(index, &idx_value)) { MS_LOG(EXCEPTION) << "Not supported type to be casted to int."; } + auto ori_value = idx_value; if (idx_value < 0) { // Add support negative index range [-size, -1]. idx_value += size; } if (idx_value < 0 || idx_value >= size) { - MS_LOG(EXCEPTION) << __FUNCTION__ << " given index " << idx_value << " out of range. Please make sure the value " - << "of index in [" << -size << ", " << size << "), and the type is int32."; + MS_EXCEPTION(IndexError) << __FUNCTION__ << " given index " << ori_value + << " out of range. Please make sure the value " + << "of index in [" << -size << ", " << size << "), and the type is int32."; } Push(branches[idx_value]); MS_LOG(DEBUG) << "End"; diff --git a/tests/st/control/test_switch_layer.py b/tests/st/control/test_switch_layer.py index 4accb44f1a..c6af14343b 100644 --- a/tests/st/control/test_switch_layer.py +++ b/tests/st/control/test_switch_layer.py @@ -52,5 +52,5 @@ def test_switch_layer(): assert ret idx3 = Tensor(3, mstype.int32) - with pytest.raises(RuntimeError): + with pytest.raises(IndexError): value = net(data, idx3, idx2)