Refactor and simplify the inlining procedure.

pull/9312/head
Zhang Qinghua 4 years ago
parent b273a46c53
commit 3ea67d4549

@ -70,46 +70,21 @@ class ReplaceApplicator : public AnfVisitor {
}
};
using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>;
class InlinerBase;
using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>;
bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
auto n_cnode = fg->nodes().size() - fg->parameters().size();
// There is at least one CNode(return, other_node).
return n_cnode <= 2;
}
bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
auto &cnodes = fg->func_graph_cnodes_index();
int64_t n_use = std::accumulate(
cnodes.begin(), cnodes.end(), 0,
[](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
return n_use == 1;
}
bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node->func_graph());
return node->func_graph()->has_flag("inline_inside");
}
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) {
bool unique_use = IsUniqueUse(fg, nullptr);
bool is_recursive = fg->recursive();
if (fg->parent() != nullptr && is_recursive) {
if (fg->parent() == node->func_graph() && unique_use) {
return true;
}
}
return false;
}
bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
// {G, Xs}
class InlinerBase : public AnfVisitor {
public:
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true)
explicit InlinerBase(std::vector<std::vector<CriterionFuncType>> criterions, bool use_move = true)
: use_move_(use_move), criterions_(criterions) {}
~InlinerBase() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -135,22 +110,24 @@ class InlinerBase : public AnfVisitor {
return nullptr;
}
}
Reset();
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
// All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND].
// Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR].
bool is_match = false;
for (auto &criterion : criterions_) {
if (!criterion.first(fg, node)) {
continue;
for (auto &criterions : criterions_) { // Each 'criterion group' in criterions_.
is_match = true;
for (auto &criterion : criterions) { // Each criterion in 'criterion group'.
if (!criterion(this, fg, node)) {
is_match = false;
break;
}
}
if (criterion.second && IsRecursive(fg)) {
continue;
if (is_match) {
break;
}
is_match = true;
break;
}
if (!is_match) {
return nullptr;
}
@ -162,24 +139,19 @@ class InlinerBase : public AnfVisitor {
if (fg->parameters().size() != args.size()) {
return nullptr;
}
auto is_unique_use = IsUniqueUse(fg, nullptr);
// Not to inline after block if it has switch call inside, to avoid switch expansion.
if (!is_unique_use && fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
auto has_branch_call = GraphHasBranch(fg);
if (has_branch_call) {
return TransformBranchCall(fg, node, args);
if (IsUniqueUse(nullptr, fg, nullptr)) {
if (use_move_) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
} else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) {
// Not to inline after block if it has switch call inside, to avoid switch expansion.
return TransformBranchCall(fg, node, args);
}
if (use_move_ && is_unique_use) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
@ -208,6 +180,7 @@ class InlinerBase : public AnfVisitor {
is_checked_ = false;
is_recursive_ = false;
}
// For after block which contains branch call, delete the parameters which is not used.
// In most cases, it may be a `Module` or other constant input.
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
@ -298,25 +271,62 @@ class InlinerBase : public AnfVisitor {
private:
bool is_checked_{false}, is_recursive_{false};
bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
std::vector<std::vector<CriterionFuncType>> criterions_;
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
// Key is the old func graph, and the value is the new func_graph
std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_;
};
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
auto &cnodes = fg->func_graph_cnodes_index();
int64_t n_use = std::accumulate(
cnodes.begin(), cnodes.end(), 0,
[](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
return n_use == 1;
}
bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
auto n_cnode = fg->nodes().size() - fg->parameters().size();
// There is at least one CNode(return, other_node).
return n_cnode <= 2;
}
bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node->func_graph());
return node->func_graph()->has_flag("inline_inside");
}
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); }
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node) {
bool unique_use = IsUniqueUse(nullptr, fg, nullptr);
bool is_recursive = fg->recursive();
if (fg->parent() != nullptr && is_recursive) {
if (fg->parent() == node->func_graph() && unique_use) {
return true;
}
}
return false;
}
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) {
return !inliner->IsRecursive(fg);
}
class Inliner : public InlinerBase {
public:
explicit Inliner(bool use_move = true)
: InlinerBase(
// Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
{
{IsUniqueUse, true},
{IsTrivial, false},
{IsInside, false},
{IsCore, false},
{IsDirectParentCall, false},
{NoCriterion, true},
{IsTrivial},
{IsInside},
{IsCore},
{IsNotRecursive},
{IsDirectParentCall},
},
use_move) {}
~Inliner() override = default;
};
@ -324,8 +334,9 @@ class DirectInliner : public InlinerBase {
public:
explicit DirectInliner(bool use_move = true)
: InlinerBase(
// Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
{
{IsDirectParentCall, false},
{IsDirectParentCall},
},
use_move) {}
~DirectInliner() override = default;

Loading…
Cancel
Save