modify controldepend mount node

pull/7755/head
wilfChen 5 years ago
parent b4ce0aa933
commit e877f72bcf

@ -75,20 +75,23 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<
group++;
}
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes) {
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes,
const AnfNodePtr aggregate_node) {
std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
inplace_nodes[0].node, inplace_nodes[1].node};
auto control_depend_node = graph->NewCNode(inputs1);
auto return_node = graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
// mount the `depend` before make_tuple, otherwise the output of graph will be `(tensor, )` rather than `tensor`
auto return_input = return_node->input(kFirstDataInputIndex)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(return_input);
std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
return_input->input(kFirstDataInputIndex), control_depend_node};
aggregate_node, control_depend_node};
auto depend_node = graph->NewCNode(inputs2);
return_node->set_input(kFirstDataInputIndex, depend_node);
auto users = GetRealNodeUsedList(graph, aggregate_node);
if (users->size() == 0) {
MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString();
}
auto mount_node = users->at(0).first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mount_node);
mount_node->set_input(kFirstDataInputIndex, depend_node);
}
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
@ -186,7 +189,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
// 2. Set Node attr
SetNodeAttr(aggregate_node, skip_node, &inplace_node);
// 3. Set dependence for inplace nodes
InsertControlDependToGraph(graph, inplace_node);
InsertControlDependToGraph(graph, inplace_node, aggregate_node.node);
}
return true;

@ -108,6 +108,7 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));

Loading…
Cancel
Save