diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 74353f9ce2..f35d655820 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1852,14 +1852,14 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { if (pre_cnode == nullptr) { return loss_node_info; } - pre_cnode = HandleDependLoss(pre_cnode); - auto current_prim = GetValueNode(pre_cnode->input(0)); + auto prim = GetValueNode(pre_cnode->input(0)); // return -> cast - if (current_prim->name() == CAST && !pre_cnode->has_user_data()) { + if (prim->name() == CAST && !pre_cnode->has_user_data()) { pre_cnode = pre_cnode->input(1)->cast(); MS_EXCEPTION_IF_NULL(pre_cnode); - current_prim = GetValueNode(pre_cnode->input(0)); } + pre_cnode = HandleDependLoss(pre_cnode); + auto current_prim = GetValueNode(pre_cnode->input(0)); // notice: the GetNext op has not input if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { @@ -2416,6 +2416,12 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG // shape op doesn't have params and attrs. OperatorParams params; OperatorAttrs attrs; + auto shape_value = GetValueNode(node->input(2))->cast(); + MS_EXCEPTION_IF_NULL(shape_value); + auto shape = shape_value->value(); + if (shape.empty()) { + return; + } OperatorArgs args = std::make_pair(attrs, params); Operator op = std::make_pair(SHAPE_OP, args); InsertNode(op, node, 2, pre_node, root, "shape");