|
|
|
@ -60,127 +60,23 @@ std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, co
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
|
|
|
|
|
#if ENABLE_D
|
|
|
|
|
std::vector<PrimitivePtr> basic_ops = {
|
|
|
|
|
prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum,
|
|
|
|
|
prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt,
|
|
|
|
|
prim::kPrimExpandDims, prim::kPrimReciprocal, prim::kPrimLessEqual};
|
|
|
|
|
if (!is_before_kernel_select) {
|
|
|
|
|
basic_ops.push_back(prim::kPrimCast);
|
|
|
|
|
}
|
|
|
|
|
#elif ENABLE_GPU
|
|
|
|
|
std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
|
|
|
|
|
#else
|
|
|
|
|
std::vector<PrimitivePtr> basic_ops;
|
|
|
|
|
#endif
|
|
|
|
|
return std::any_of(basic_ops.begin(), basic_ops.end(),
|
|
|
|
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsReduceOp(const AnfNodePtr &node) {
|
|
|
|
|
std::vector<PrimitivePtr> reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin,
|
|
|
|
|
prim::kPrimReduceMax, prim::kPrimReduceAll};
|
|
|
|
|
return std::any_of(reduce_ops.begin(), reduce_ops.end(),
|
|
|
|
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetGraphKernelInfo(const FuncGraphPtr &fg, GraphKernelInfo *info) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fg);
|
|
|
|
|
auto mng = fg->manager();
|
|
|
|
|
if (mng == nullptr) {
|
|
|
|
|
mng = Manage(fg, false);
|
|
|
|
|
fg->set_manager(mng);
|
|
|
|
|
}
|
|
|
|
|
const auto &nodes = fg->nodes();
|
|
|
|
|
info->op_type = ELEWISE;
|
|
|
|
|
info->cal_step = -1;
|
|
|
|
|
info->reduce_op_num = 0;
|
|
|
|
|
for (auto node : nodes) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
info->cal_step++;
|
|
|
|
|
if (IsReduceOp(node)) {
|
|
|
|
|
info->op_type = REDUCE;
|
|
|
|
|
info->reduce_op_num++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto fg_flag = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
|
|
|
|
if (fg_flag != nullptr) {
|
|
|
|
|
auto fg_name = GetValue<std::string>(fg_flag);
|
|
|
|
|
info->origin_composite_name = fg_name;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsCompositeFuseBasic(const GraphKernelInfo &info, const AnfNodePtr &node) {
|
|
|
|
|
#if ENABLE_D
|
|
|
|
|
std::vector<PrimitivePtr> fusable_with_reduce;
|
|
|
|
|
if (!info.is_before_kernel_select) {
|
|
|
|
|
fusable_with_reduce.push_back(prim::kPrimCast);
|
|
|
|
|
}
|
|
|
|
|
if (info.op_type == REDUCE &&
|
|
|
|
|
(info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) {
|
|
|
|
|
return std::any_of(fusable_with_reduce.begin(), fusable_with_reduce.end(),
|
|
|
|
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
return IsBasicFuseOp(node, info.is_before_kernel_select);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) {
|
|
|
|
|
bool IsFuse(const AnfNodePtr &node) {
|
|
|
|
|
// composite fuse composite op
|
|
|
|
|
if (AnfAlgo::IsGraphKernel(node)) {
|
|
|
|
|
#if ENABLE_D
|
|
|
|
|
return false;
|
|
|
|
|
#else
|
|
|
|
|
return true;
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
return IsCompositeFuseBasic(info, node);
|
|
|
|
|
return IsBasicFuseOp(node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Folds the contribution of `node` into the running fusion record `info`.
//
//   - Primitive CNode: cal_step grows by one, op_type becomes REDUCE when the node is a
//     reduce op, and "<primitive name>_" is appended to origin_composite_name.
//   - Graph-kernel node: the statistics of its sub-graph (taken from input 0) are computed
//     with GetGraphKernelInfo(); its cal_step is added and its origin_composite_name appended.
//   - Any other node: `info` is left unchanged.
//
// NOTE(review): reduce_op_num is not updated in either branch, and op_type is not taken
// over from a fused graph kernel, whereas GetGraphKernelInfo() maintains both - confirm
// whether this is intentional.
void UpdateGraphKernelInfo(GraphKernelInfo *info, const AnfNodePtr &node) {
  if (IsPrimitiveCNode(node)) {
    ++info->cal_step;
    if (IsReduceOp(node)) {
      info->op_type = REDUCE;
    }
    info->origin_composite_name += AnfAlgo::GetCNodePrimitive(node)->name() + "_";
    return;
  }

  if (!AnfAlgo::IsGraphKernel(node)) {
    return;
  }

  // Input 0 of a graph-kernel CNode holds the sub-graph as a value node.
  auto kernel_cnode = node->cast<CNodePtr>();
  auto sub_graph = GetValueNode<FuncGraphPtr>(kernel_cnode->input(0));
  GraphKernelInfo sub_info;
  GetGraphKernelInfo(sub_graph, &sub_info);
  info->cal_step += sub_info.cal_step;
  info->origin_composite_name += sub_info.origin_composite_name;
}
|
|
|
|
|
|
|
|
|
|
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
|
|
|
|
|
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
|
|
|
|
if (cur_node == node) {
|
|
|
|
|
return FOLLOW;
|
|
|
|
|
}
|
|
|
|
|
#if ENABLE_D
|
|
|
|
|
if (!IsPrimitiveCNode(node)) {
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
bool is_fuse_composite = AnfAlgo::IsGraphKernel(node);
|
|
|
|
|
if (!IsPrimitiveCNode(node) && !is_fuse_composite) {
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
bool is_fusable = IsFuse(*info, node);
|
|
|
|
|
if (is_fusable) {
|
|
|
|
|
UpdateGraphKernelInfo(info, node);
|
|
|
|
|
}
|
|
|
|
|
bool is_fusable = IsFuse(node);
|
|
|
|
|
return is_fusable ? FOLLOW : EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
|
|
|
|
|
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
|
|
|
|
if (cur_node == node) {
|
|
|
|
|
return FOLLOW;
|
|
|
|
|
}
|
|
|
|
@ -195,14 +91,7 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
|
|
|
|
|
}
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
if (!IsPrimitiveCNode(node)) {
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool is_fusable = IsFuse(*info, node);
|
|
|
|
|
if (is_fusable) {
|
|
|
|
|
UpdateGraphKernelInfo(info, node);
|
|
|
|
|
}
|
|
|
|
|
bool is_fusable = IsBasicFuseOp(node);
|
|
|
|
|
return is_fusable ? FOLLOW : EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -350,19 +239,15 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
|
|
|
|
lst->assign(res.begin(), res.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
|
|
|
|
|
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
|
|
|
|
auto func_graph = cnode->func_graph();
|
|
|
|
|
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
|
|
|
|
GraphKernelInfo info;
|
|
|
|
|
info.is_before_kernel_select = is_before_kernel_select;
|
|
|
|
|
GetGraphKernelInfo(graph_kernel_g, &info);
|
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
|
// Search fusable nodes according input direction.
|
|
|
|
|
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, &info, std::placeholders::_1);
|
|
|
|
|
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
|
|
|
|
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
|
|
|
|
std::reverse(used_nodes.begin(), used_nodes.end());
|
|
|
|
|
// Search fusable nodes according output direction.
|
|
|
|
|
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, &info, std::placeholders::_1);
|
|
|
|
|
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, std::placeholders::_1);
|
|
|
|
|
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
|
|
|
|
|
|
|
|
|
|
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
|
|
|
@ -373,7 +258,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_ker
|
|
|
|
|
return used_nodes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
|
|
|
|
|
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
bool changed = false;
|
|
|
|
|
auto &todos = kernel_graph->execution_order();
|
|
|
|
@ -392,19 +277,19 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
|
|
|
|
|
auto fuse_nodes = FindFuseCNodes(node);
|
|
|
|
|
if (fuse_nodes.size() <= 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
changed = true;
|
|
|
|
|
|
|
|
|
|
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "", is_before_kernel_select);
|
|
|
|
|
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
|
|
|
|
|
}
|
|
|
|
|
return changed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph), false);
|
|
|
|
|
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|