!4576 Support if by if not inline

Merge pull request !4576 from amongo/SupportIfByIfNotInline
pull/4576/MERGE
Committed by mindspore-ci-bot via Gitee
commit 4f6e63fcf8

@@ -20,12 +20,14 @@
#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
@@ -153,23 +155,31 @@ class InlinerBase : public AnfVisitor {
return nullptr;
}
std::vector<AnfNodePtr> params;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
std::vector<AnfNodePtr> args;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
// Compare the sizes to avoid the case where the function has default values after grad,
// for which, after renormalize, the default value becomes an extra input.
if (fg->parameters().size() != params.size()) {
if (fg->parameters().size() != args.size()) {
return nullptr;
}
// Do not inline the after-block if it contains a switch call, to avoid switch expansion.
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
auto has_branch_call = GraphHasBranch(fg);
if (has_branch_call) {
return TransformBranchCall(fg, node, args);
}
}
if (use_move_ && IsUniqueUse(fg, nullptr)) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, params, fg);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
@@ -197,11 +207,89 @@ class InlinerBase : public AnfVisitor {
is_checked_ = false;
is_recursive_ = false;
}
// For an after-block that contains a branch call, delete the parameters which are not used.
// In most cases the unused parameter is a `Module` or other constant input.
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
auto &fg_params = fg->parameters();
std::vector<int> used_param_index;
auto mng = fg->manager();
for (size_t i = 0; i < fg_params.size(); i++) {
if (mng->node_users()[fg_params[i]].size() != 0) {
used_param_index.emplace_back(i);
}
}
if (used_param_index.size() != fg_params.size()) {
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
// Clone a new graph and drop the unused parameters.
FuncGraphPtr new_fg = TransformableClone(fg);
auto &new_fg_params = new_fg->parameters();
std::vector<AnfNodePtr> new_params;
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
[&new_fg_params](size_t i) { return new_fg_params[i]; });
new_fg->set_parameters(new_params);
std::vector<AnfNodePtr> node_inputs;
node_inputs.push_back(NewValueNode(new_fg));
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
[&args](size_t i) { return args[i]; });
return node->func_graph()->NewCNode(node_inputs);
}
return nullptr;
}
// This is a best-effort algorithm to find a graph which may generate a branch call.
// It does not handle high-order function calls; a branch reached through a high-order call may still be inlined.
bool GraphHasBranch(FuncGraphPtr fg) {
if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
return graph_branch_cache_[fg];
}
bool has_branch = false;
auto nodes = fg->nodes();
for (auto &item : nodes) {
if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
auto sw_inputs = item->cast<CNodePtr>()->inputs();
if (sw_inputs.size() != 4) {
MS_LOG(EXCEPTION) << "switch inputs should be 4";
}
if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
has_branch = true;
break;
}
} else if (IsCNodeGraph(item)) {
auto cinputs = item->cast<CNodePtr>()->inputs();
if (cinputs.size() < 1) {
MS_LOG(EXCEPTION) << "graph call inputs should greater than 1";
}
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
bool call_fg_has_branch = GraphHasBranch(call_fg);
if (call_fg_has_branch) {
has_branch = true;
break;
}
} else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
auto cinputs = item->cast<CNodePtr>()->inputs();
if (cinputs.size() < 2) {
MS_LOG(EXCEPTION) << "partial call inputs should greater than 2";
}
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
if (call_fg == nullptr) {
continue;
}
bool call_fg_has_branch = GraphHasBranch(call_fg);
if (call_fg_has_branch) {
has_branch = true;
break;
}
}
}
graph_branch_cache_[fg] = has_branch;
return has_branch;
}
private:
bool is_checked_{false}, is_recursive_{false};
bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
};
class Inliner : public InlinerBase {

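The core of the new TransformBranchCall step is index-based pruning: collect the indices of the parameters that still have users, then rebuild both the cloned graph's parameter list and the call-site argument list from those indices so the two stay aligned. A minimal standalone sketch of that idea, using plain std::string vectors as stand-ins for the AnfNodePtr lists (all names below are illustrative, not MindSpore APIs):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-ins for fg->parameters() and the call-site arguments (same length).
  std::vector<std::string> params = {"%module", "x", "y"};
  std::vector<std::string> args = {"ValueNode(Module)", "arg_x", "arg_y"};
  // Suppose only parameters 1 and 2 have users; index 0 (the Module) is unused.
  std::vector<size_t> used_param_index = {1, 2};

  std::vector<std::string> new_params;
  std::vector<std::string> new_args;
  for (size_t i : used_param_index) {
    new_params.push_back(params[i]);  // formal parameters kept in the cloned graph
    new_args.push_back(args[i]);      // matching arguments for the new call node
  }

  // The new call passes only the kept arguments, in the same order as the kept parameters.
  for (size_t k = 0; k < new_params.size(); ++k) {
    std::cout << new_params[k] << " <- " << new_args[k] << std::endl;
  }
  return 0;
}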
@@ -1029,6 +1029,12 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
if (MsContext::GetInstance()->backend_policy() != "ge") {
// Backends other than 'ge' can handle multi-graph calls, so use this flag to
// generate a call to the `after_block` graph instead of inlining it, reducing if-by-if switch expansion.
after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
}
// process the if-true branch
py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);

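FUNC_GRAPH_FLAG_AFTER_BLOCK is the handshake between the parser and the inliner: ParseIf marks the after-block's graph here, and InlinerBase checks the flag before deciding to transform the call rather than inline it. A toy model of the set_flag/has_flag pattern (this is not MindSpore's actual FuncGraph class, just a sketch of the mechanism):

#include <string>
#include <unordered_map>

constexpr char kFlagAfterBlock[] = "after_block";

// Toy stand-in for the flag storage on a FuncGraph.
class ToyFuncGraph {
 public:
  void set_flag(const std::string &flag, bool value) { flags_[flag] = value; }
  bool has_flag(const std::string &flag) const {
    auto it = flags_.find(flag);
    return it != flags_.end() && it->second;
  }

 private:
  std::unordered_map<std::string, bool> flags_;
};

int main() {
  ToyFuncGraph after_block;
  // Parser side (ParseIf): mark the after-block when the backend is not 'ge'.
  after_block.set_flag(kFlagAfterBlock, true);
  // Optimizer side (InlinerBase): consult the flag before inlining.
  if (after_block.has_flag(kFlagAfterBlock)) {
    // Here the real pass would call GraphHasBranch and possibly TransformBranchCall.
  }
  return 0;
}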
@@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";

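For reference, GraphHasBranch above is a memoized recursive check: a graph "has a branch" if it contains a data-dependent switch itself, or if any graph it calls (directly or through Partial) does, and the results are cached in graph_branch_cache_ so shared subgraphs are analysed only once. A self-contained sketch of that caching recursion over a toy call graph, with integer indices standing in for FuncGraphPtr:

#include <iostream>
#include <unordered_map>
#include <vector>

// Toy program: calls[i] lists the graphs called by graph i,
// has_switch[i] says whether graph i itself contains a data-dependent switch.
struct ToyProgram {
  std::vector<std::vector<int>> calls;
  std::vector<bool> has_switch;
};

bool HasBranch(const ToyProgram &prog, int g, std::unordered_map<int, bool> *cache) {
  auto it = cache->find(g);
  if (it != cache->end()) {
    return it->second;  // memoized result, like graph_branch_cache_
  }
  bool result = prog.has_switch[g];
  if (!result) {
    for (int callee : prog.calls[g]) {
      if (HasBranch(prog, callee, cache)) {
        result = true;
        break;
      }
    }
  }
  (*cache)[g] = result;
  return result;
}

int main() {
  // Graph 0 calls graphs 1 and 2; graph 2 also calls graph 1; only graph 1 has a switch.
  ToyProgram prog{{{1, 2}, {}, {1}}, {false, true, false}};
  std::unordered_map<int, bool> cache;
  std::cout << std::boolalpha << HasBranch(prog, 0, &cache) << std::endl;  // prints: true
  return 0;
}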
(Two more file diffs were suppressed because they are too large to display.)