|
|
|
@ -1852,14 +1852,14 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
|
|
|
|
|
if (pre_cnode == nullptr) {
|
|
|
|
|
return loss_node_info;
|
|
|
|
|
}
|
|
|
|
|
pre_cnode = HandleDependLoss(pre_cnode);
|
|
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
// return -> cast
|
|
|
|
|
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
if (prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode);
|
|
|
|
|
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
}
|
|
|
|
|
pre_cnode = HandleDependLoss(pre_cnode);
|
|
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
|
|
|
|
|
// notice: the GetNext op has not input
|
|
|
|
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
|
|
|
@ -2416,6 +2416,12 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
|
|
|
|
|
// shape op doesn't have params and attrs.
|
|
|
|
|
OperatorParams params;
|
|
|
|
|
OperatorAttrs attrs;
|
|
|
|
|
auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shape_value);
|
|
|
|
|
auto shape = shape_value->value();
|
|
|
|
|
if (shape.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
|
Operator op = std::make_pair(SHAPE_OP, args);
|
|
|
|
|
InsertNode(op, node, 2, pre_node, root, "shape");
|
|
|
|
|