Refactor and simplify the inlining procedure.

pull/9312/head
Zhang Qinghua 4 years ago
parent b273a46c53
commit 3ea67d4549

@ -70,46 +70,21 @@ class ReplaceApplicator : public AnfVisitor {
} }
}; };
using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>; class InlinerBase;
using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>;
bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
auto n_cnode = fg->nodes().size() - fg->parameters().size();
// There is at least one CNode(return, other_node).
return n_cnode <= 2;
}
bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
auto &cnodes = fg->func_graph_cnodes_index();
int64_t n_use = std::accumulate(
cnodes.begin(), cnodes.end(), 0,
[](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
return n_use == 1;
}
bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node->func_graph());
return node->func_graph()->has_flag("inline_inside");
}
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) { bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool unique_use = IsUniqueUse(fg, nullptr); bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
bool is_recursive = fg->recursive(); bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
if (fg->parent() != nullptr && is_recursive) { bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
if (fg->parent() == node->func_graph() && unique_use) { bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
return true;
}
}
return false;
}
// {G, Xs} // {G, Xs}
class InlinerBase : public AnfVisitor { class InlinerBase : public AnfVisitor {
public: public:
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true) explicit InlinerBase(std::vector<std::vector<CriterionFuncType>> criterions, bool use_move = true)
: use_move_(use_move), criterions_(criterions) {} : use_move_(use_move), criterions_(criterions) {}
~InlinerBase() override = default; ~InlinerBase() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -135,22 +110,24 @@ class InlinerBase : public AnfVisitor {
return nullptr; return nullptr;
} }
} }
Reset(); Reset();
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
// All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND].
// Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR].
bool is_match = false; bool is_match = false;
for (auto &criterion : criterions_) { for (auto &criterions : criterions_) { // Each 'criterion group' in criterions_.
if (!criterion.first(fg, node)) { is_match = true;
continue; for (auto &criterion : criterions) { // Each criterion in 'criterion group'.
if (!criterion(this, fg, node)) {
is_match = false;
break;
}
} }
if (is_match) {
if (criterion.second && IsRecursive(fg)) { break;
continue;
} }
is_match = true;
break;
} }
if (!is_match) { if (!is_match) {
return nullptr; return nullptr;
} }
@ -162,24 +139,19 @@ class InlinerBase : public AnfVisitor {
if (fg->parameters().size() != args.size()) { if (fg->parameters().size() != args.size()) {
return nullptr; return nullptr;
} }
auto is_unique_use = IsUniqueUse(fg, nullptr); if (IsUniqueUse(nullptr, fg, nullptr)) {
// Not to inline after block if it has switch call inside, to avoid switch expansion. if (use_move_) {
if (!is_unique_use && fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { auto mng = fg->manager();
auto has_branch_call = GraphHasBranch(fg); MS_EXCEPTION_IF_NULL(mng);
if (has_branch_call) { ReplaceParams(mng, args, fg);
return TransformBranchCall(fg, node, args); auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
} }
} else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) {
// Not to inline after block if it has switch call inside, to avoid switch expansion.
return TransformBranchCall(fg, node, args);
} }
if (use_move_ && is_unique_use) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
} }
@ -208,6 +180,7 @@ class InlinerBase : public AnfVisitor {
is_checked_ = false; is_checked_ = false;
is_recursive_ = false; is_recursive_ = false;
} }
// For after block which contains branch call, delete the parameters which is not used. // For after block which contains branch call, delete the parameters which is not used.
// In most cases, it may be a `Module` or other constant input. // In most cases, it may be a `Module` or other constant input.
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) { AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
@ -298,25 +271,62 @@ class InlinerBase : public AnfVisitor {
private: private:
bool is_checked_{false}, is_recursive_{false}; bool is_checked_{false}, is_recursive_{false};
bool use_move_; bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_; std::vector<std::vector<CriterionFuncType>> criterions_;
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_; std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
// Key is the old func graph, and the value is the new func_graph // Key is the old func graph, and the value is the new func_graph
std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_; std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_;
}; };
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
auto &cnodes = fg->func_graph_cnodes_index();
int64_t n_use = std::accumulate(
cnodes.begin(), cnodes.end(), 0,
[](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
return n_use == 1;
}
bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
auto n_cnode = fg->nodes().size() - fg->parameters().size();
// There is at least one CNode(return, other_node).
return n_cnode <= 2;
}
bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node->func_graph());
return node->func_graph()->has_flag("inline_inside");
}
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); }
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node) {
bool unique_use = IsUniqueUse(nullptr, fg, nullptr);
bool is_recursive = fg->recursive();
if (fg->parent() != nullptr && is_recursive) {
if (fg->parent() == node->func_graph() && unique_use) {
return true;
}
}
return false;
}
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) {
return !inliner->IsRecursive(fg);
}
class Inliner : public InlinerBase { class Inliner : public InlinerBase {
public: public:
explicit Inliner(bool use_move = true) explicit Inliner(bool use_move = true)
: InlinerBase( : InlinerBase(
// Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
{ {
{IsUniqueUse, true}, {IsTrivial},
{IsTrivial, false}, {IsInside},
{IsInside, false}, {IsCore},
{IsCore, false}, {IsNotRecursive},
{IsDirectParentCall, false}, {IsDirectParentCall},
{NoCriterion, true},
}, },
use_move) {} use_move) {}
~Inliner() override = default; ~Inliner() override = default;
}; };
@ -324,8 +334,9 @@ class DirectInliner : public InlinerBase {
public: public:
explicit DirectInliner(bool use_move = true) explicit DirectInliner(bool use_move = true)
: InlinerBase( : InlinerBase(
// Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
{ {
{IsDirectParentCall, false}, {IsDirectParentCall},
}, },
use_move) {} use_move) {}
~DirectInliner() override = default; ~DirectInliner() override = default;

Loading…
Cancel
Save