|
|
|
@ -20,12 +20,14 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
|
|
|
|
|
#include "frontend/optimizer/irpass.h"
|
|
|
|
|
#include "frontend/optimizer/optimizer.h"
|
|
|
|
|
#include "frontend/optimizer/anf_visitor.h"
|
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
|
#include "ir/tensor.h"
|
|
|
|
|
#include "frontend/operator/ops.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
@ -153,23 +155,31 @@ class InlinerBase : public AnfVisitor {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> params;
|
|
|
|
|
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
|
|
|
|
|
std::vector<AnfNodePtr> args;
|
|
|
|
|
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
|
|
|
|
|
// compare size to avoid the case that the function has default value after grad.
|
|
|
|
|
// for which after renormalize, the function default value will be an input
|
|
|
|
|
if (fg->parameters().size() != params.size()) {
|
|
|
|
|
if (fg->parameters().size() != args.size()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// Not to inline after block if it has switch call inside, to avoid switch expansion.
|
|
|
|
|
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
|
|
|
|
|
auto has_branch_call = GraphHasBranch(fg);
|
|
|
|
|
if (has_branch_call) {
|
|
|
|
|
return TransformBranchCall(fg, node, args);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (use_move_ && IsUniqueUse(fg, nullptr)) {
|
|
|
|
|
auto mng = fg->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
|
|
|
ReplaceParams(mng, params, fg);
|
|
|
|
|
ReplaceParams(mng, args, fg);
|
|
|
|
|
auto out_node = fg->output();
|
|
|
|
|
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
|
|
|
|
|
return out_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
|
|
|
|
|
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
|
|
|
|
@ -197,11 +207,89 @@ class InlinerBase : public AnfVisitor {
|
|
|
|
|
is_checked_ = false;
|
|
|
|
|
is_recursive_ = false;
|
|
|
|
|
}
|
|
|
|
|
// For after block which contains branch call, delete the parameters which is not used.
|
|
|
|
|
// In most cases, it may be a `Module` or other constant input.
|
|
|
|
|
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
|
|
|
|
|
auto &fg_params = fg->parameters();
|
|
|
|
|
std::vector<int> used_param_index;
|
|
|
|
|
auto mng = fg->manager();
|
|
|
|
|
for (size_t i = 0; i < fg_params.size(); i++) {
|
|
|
|
|
if (mng->node_users()[fg_params[i]].size() != 0) {
|
|
|
|
|
used_param_index.emplace_back(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (used_param_index.size() != fg_params.size()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
|
|
|
|
|
// clone a new graph and ignore the not used parameters
|
|
|
|
|
FuncGraphPtr new_fg = TransformableClone(fg);
|
|
|
|
|
auto &new_fg_params = new_fg->parameters();
|
|
|
|
|
std::vector<AnfNodePtr> new_params;
|
|
|
|
|
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
|
|
|
|
|
[&new_fg_params](size_t i) { return new_fg_params[i]; });
|
|
|
|
|
new_fg->set_parameters(new_params);
|
|
|
|
|
std::vector<AnfNodePtr> node_inputs;
|
|
|
|
|
node_inputs.push_back(NewValueNode(new_fg));
|
|
|
|
|
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
|
|
|
|
|
[&args](size_t i) { return args[i]; });
|
|
|
|
|
return node->func_graph()->NewCNode(node_inputs);
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This is a try-best algorithm to find a graph which may generate branch call.
|
|
|
|
|
// It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
|
|
|
|
|
bool GraphHasBranch(FuncGraphPtr fg) {
|
|
|
|
|
if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
|
|
|
|
|
return graph_branch_cache_[fg];
|
|
|
|
|
}
|
|
|
|
|
bool has_branch = false;
|
|
|
|
|
auto nodes = fg->nodes();
|
|
|
|
|
for (auto &item : nodes) {
|
|
|
|
|
if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
|
|
|
|
|
auto sw_inputs = item->cast<CNodePtr>()->inputs();
|
|
|
|
|
if (sw_inputs.size() != 4) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "switch inputs should be 4";
|
|
|
|
|
}
|
|
|
|
|
if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
|
|
|
|
|
has_branch = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
} else if (IsCNodeGraph(item)) {
|
|
|
|
|
auto cinputs = item->cast<CNodePtr>()->inputs();
|
|
|
|
|
if (cinputs.size() < 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "graph call inputs should greater than 1";
|
|
|
|
|
}
|
|
|
|
|
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
|
|
|
|
|
bool call_fg_has_branch = GraphHasBranch(call_fg);
|
|
|
|
|
if (call_fg_has_branch) {
|
|
|
|
|
has_branch = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
} else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
|
|
|
|
|
auto cinputs = item->cast<CNodePtr>()->inputs();
|
|
|
|
|
if (cinputs.size() < 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "partial call inputs should greater than 2";
|
|
|
|
|
}
|
|
|
|
|
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
|
|
|
|
|
if (call_fg == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
bool call_fg_has_branch = GraphHasBranch(call_fg);
|
|
|
|
|
if (call_fg_has_branch) {
|
|
|
|
|
has_branch = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
graph_branch_cache_[fg] = has_branch;
|
|
|
|
|
return has_branch;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool is_checked_{false}, is_recursive_{false};
|
|
|
|
|
bool use_move_;
|
|
|
|
|
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
|
|
|
|
|
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Inliner : public InlinerBase {
|
|
|
|
|