!7022 [AutoParallel]Fix find loss and root reshape bug

Merge pull request !7022 from lichen/fix_auto_parallel_find_loss_bug
pull/7022/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b9df01b60e

@ -1852,14 +1852,14 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
if (pre_cnode == nullptr) {
return loss_node_info;
}
pre_cnode = HandleDependLoss(pre_cnode);
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// return -> cast
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
if (prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
}
pre_cnode = HandleDependLoss(pre_cnode);
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// notice: the GetNext op has not input
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
@ -2416,6 +2416,12 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
// shape op doesn't have params and attrs.
OperatorParams params;
OperatorAttrs attrs;
auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto shape = shape_value->value();
if (shape.empty()) {
return;
}
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(SHAPE_OP, args);
InsertNode(op, node, 2, pre_node, root, "shape");

Loading…
Cancel
Save