You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
211 lines
5.8 KiB
211 lines
5.8 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
|
|
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
|
|
|
|
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>
#include <vector>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
|
|
|
|
namespace mindspore {
|
|
namespace opt {
|
|
namespace irpass {
|
|
class ReplaceApplicator : public AnfVisitor {
|
|
public:
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
if (!IsValueNode<FuncGraph>(node)) {
|
|
return nullptr;
|
|
}
|
|
|
|
auto fg = GetValueNode<FuncGraphPtr>(node);
|
|
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
|
return nullptr;
|
|
}
|
|
|
|
auto out = fg->output();
|
|
MS_EXCEPTION_IF_NULL(out);
|
|
if (!out->isa<CNode>()) {
|
|
return nullptr;
|
|
}
|
|
|
|
auto &inputs = out->cast<CNodePtr>()->inputs();
|
|
auto params = fg->parameters();
|
|
|
|
// Exclude first elements of inputs which is fn.
|
|
auto input_size = inputs.size();
|
|
auto param_size = params.size();
|
|
if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
|
|
std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
|
|
auto inner = inputs[0];
|
|
if (IsValueNode<Primitive>(inner) ||
|
|
(IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr)) {
|
|
return inner;
|
|
}
|
|
}
|
|
|
|
return nullptr;
|
|
}
|
|
};
|
|
|
|
// Signature of an inlining criterion. InlinerBase calls it with the callee graph and
// the call node, and treats a true result as "this criterion accepts inlining".
using CriterionFuncType = std::function<bool(FuncGraphPtr, AnfNodePtr)>;
|
|
|
|
bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
|
|
auto &s = fg->nodes();
|
|
int n_cnode = std::count_if(s.begin(), s.end(), [](const AnfNodePtr &n) {
|
|
MS_EXCEPTION_IF_NULL(n);
|
|
return n->isa<CNode>();
|
|
});
|
|
// There is at least one CNode(return, other_node).
|
|
return n_cnode <= 2;
|
|
}
|
|
|
|
bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
|
|
auto &cnodes = fg->func_graph_cnodes_index();
|
|
int n_use =
|
|
std::accumulate(cnodes.begin(), cnodes.end(), 0,
|
|
[](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; });
|
|
return n_use == 1;
|
|
}
|
|
|
|
// Inlining criterion: true when the graph that owns the call `node` has its
// "inline_inside" flag set to true. Returns false when the flag is absent.
// The callee-graph argument is unused.
// Declared inline because it is defined in a header (avoids duplicate definitions).
inline bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // Hold the owning graph in a local so the flags reference below is tied to a live
  // handle rather than a temporary.
  auto owner = node->func_graph();
  MS_EXCEPTION_IF_NULL(owner);
  auto &flags = owner->flags();
  // One lookup instead of find() followed by operator[].
  auto iter = flags.find("inline_inside");
  return iter != flags.end() && iter->second;
}
|
|
|
|
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) {
|
|
auto &flags = fg->flags();
|
|
if (flags.find("core") != flags.end()) {
|
|
return flags["core"];
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Inlining criterion that accepts every call. Both arguments are ignored.
// Declared inline because it is defined in a header (avoids duplicate definitions).
inline bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
|
|
|
|
// {G, Xs}
|
|
class InlinerBase : public AnfVisitor {
|
|
public:
|
|
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {}
|
|
~InlinerBase() override = default;
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
if (!node->isa<CNode>()) {
|
|
return nullptr;
|
|
}
|
|
|
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
|
if (inputs.size() < 1 || !IsValueNode<FuncGraph>(inputs[0])) {
|
|
return nullptr;
|
|
}
|
|
|
|
// G
|
|
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
|
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
|
|
return nullptr;
|
|
}
|
|
|
|
Reset();
|
|
bool is_match = false;
|
|
for (auto &criterion : criterions_) {
|
|
if (!criterion.first(fg, node)) {
|
|
continue;
|
|
}
|
|
|
|
if (criterion.second && IsRecursive(fg)) {
|
|
continue;
|
|
}
|
|
|
|
is_match = true;
|
|
break;
|
|
}
|
|
|
|
if (!is_match) {
|
|
return nullptr;
|
|
}
|
|
|
|
std::vector<AnfNodePtr> params;
|
|
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
|
|
|
|
if (IsUniqueUse(fg, nullptr)) {
|
|
auto mng = fg->manager();
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
ReplaceParams(mng, params, fg);
|
|
auto out_node = fg->output();
|
|
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
|
|
return out_node;
|
|
}
|
|
|
|
return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
|
|
}
|
|
|
|
void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
|
|
const FuncGraphPtr &fg) {
|
|
auto params = fg->parameters();
|
|
auto old_size = params.size();
|
|
if (old_size != new_params.size()) {
|
|
MS_LOG(EXCEPTION) << "Parameter size not match.";
|
|
}
|
|
for (size_t i = 0; i < old_size; i++) {
|
|
(void)mng->Replace(params[i], new_params[i]);
|
|
}
|
|
}
|
|
|
|
bool IsRecursive(const FuncGraphPtr &fg) {
|
|
if (!is_checked_) {
|
|
is_checked_ = true;
|
|
is_recursive_ = fg->recursive();
|
|
}
|
|
return is_recursive_;
|
|
}
|
|
|
|
void Reset() {
|
|
is_checked_ = false;
|
|
is_recursive_ = false;
|
|
}
|
|
|
|
private:
|
|
bool is_checked_{false}, is_recursive_{false};
|
|
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
|
|
};
|
|
|
|
class Inliner : public InlinerBase {
|
|
public:
|
|
Inliner()
|
|
: InlinerBase({
|
|
{IsUniqueUse, true},
|
|
{IsTrivial, false},
|
|
{IsInside, false},
|
|
{IsCore, false},
|
|
{NoCriterion, true},
|
|
}) {}
|
|
~Inliner() override = default;
|
|
};
|
|
} // namespace irpass
|
|
} // namespace opt
|
|
} // namespace mindspore
|
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
|