|
|
|
@ -1202,7 +1202,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|
|
|
|
if (node->input(index)->isa<CNode>()) {
|
|
|
|
|
auto pre_cnode = node->input(index)->cast<CNodePtr>();
|
|
|
|
|
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
if (pre_prim->name() == CAST) {
|
|
|
|
|
if ((pre_prim->name() == CAST) || (pre_prim->name() == LOAD)) {
|
|
|
|
|
manager->SetEdge(pre_cnode, 1, next_cnode.second);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -1217,10 +1217,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|
|
|
|
MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
|
|
|
|
|
}
|
|
|
|
|
std::string instance_name = MIRROR_OP;
|
|
|
|
|
if (IsCastBeforMirror(node, index)) {
|
|
|
|
|
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
|
|
|
|
|
if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
|
|
|
|
|
for (auto &op : backward_op) {
|
|
|
|
|
// insert new node before the node
|
|
|
|
|
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
AnfNodePtr pre_node = cnode->input(1);
|
|
|
|
|
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
|
|
|
|
|