|
|
|
|
@ -75,20 +75,23 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
|
|
|
|
|
group++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) {
|
|
|
|
|
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes,
|
|
|
|
|
const AnfNodePtr aggregate_node) {
|
|
|
|
|
std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
|
|
|
|
|
inplace_nodes[0].node, inplace_nodes[1].node};
|
|
|
|
|
auto control_depend_node = graph->NewCNode(inputs1);
|
|
|
|
|
|
|
|
|
|
auto return_node = graph->get_return();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(return_node);
|
|
|
|
|
// mount the `depend` before make_tuple, otherwise the output of graph will be `(tensor, )` rather than `tensor`
|
|
|
|
|
auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(return_input);
|
|
|
|
|
std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
|
|
|
|
return_input->input(kFirstDataInputIndex), control_depend_node};
|
|
|
|
|
aggregate_node, control_depend_node};
|
|
|
|
|
auto depend_node = graph->NewCNode(inputs2);
|
|
|
|
|
return_node->set_input(kFirstDataInputIndex, depend_node);
|
|
|
|
|
|
|
|
|
|
auto users = GetRealNodeUsedList(graph, aggregate_node);
|
|
|
|
|
if (users->size() == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto mount_node = users->at(0).first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mount_node);
|
|
|
|
|
mount_node->set_input(kFirstDataInputIndex, depend_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
|
|
|
|
|
@ -186,7 +189,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
|
|
|
|
|
// 2. Set Node attr
|
|
|
|
|
SetNodeAttr(aggregate_node, skip_node, &inplace_node);
|
|
|
|
|
// 3. Set dependence for inplace nodes
|
|
|
|
|
InsertControlDependToGraph(graph, inplace_node);
|
|
|
|
|
InsertControlDependToGraph(graph, inplace_node, aggregate_node.node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|