insert mirror before load

pull/13105/head
yangzhenzhang 4 years ago
parent ca8c07d65b
commit 6eadd241a0

@ -1202,7 +1202,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
if (node->input(index)->isa<CNode>()) {
auto pre_cnode = node->input(index)->cast<CNodePtr>();
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST) {
if ((pre_prim->name() == CAST) || (pre_prim->name() == LOAD)) {
manager->SetEdge(pre_cnode, 1, next_cnode.second);
continue;
}
@ -1217,10 +1217,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
}
std::string instance_name = MIRROR_OP;
if (IsCastBeforMirror(node, index)) {
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
for (auto &op : backward_op) {
// insert new node before the node
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtr pre_node = cnode->input(1);
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);

Loading…
Cancel
Save