/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_

#include <vector>
#include <algorithm>
#include <memory>
#include <numeric>  // for std::accumulate (used by AddNEliminater)

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}}
// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
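// e.g. AddN((AddN((a, b)), c)) -> AddN((a, b, c)); the inner AddN is folded into
// the outer one only when its result has no other users (see is_unique below).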
class MergeAddN : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
optimizer_ = optimizer;
is_outer_ = true;
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {
return nullptr;
}
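// Rebuild the node as {AddN, {MakeTuple, merged args}}; args_ was filled with the
// flattened Xs/Ys while visiting the MakeTuple inputs, so only the leading
// MakeTuple primitive still needs to be prepended here.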
auto cnode = node->cast<CNodePtr>();
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
// {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
(void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
auto fg = node->func_graph();
auto make_node = fg->NewCNode(args_);
return fg->NewCNode({addn, make_node});
}
void Visit(const CNodePtr &cnode) override {
if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
return;
}
auto &inputs = cnode->inputs();
if (is_outer_) {
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_));
is_outer_ = false;
is_inner_ = true;
// {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]);
if (is_match_) {
if (!is_unique(inputs[1])) {
is_match_ = false;
return;
}
(void)Ys_.erase(Ys_.begin());
(void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
(void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
return;
}
// {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back());
if (is_match_) {
if (!is_unique(inputs.back())) {
is_match_ = false;
return;
}
Ys_.pop_back();
(void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
(void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
return;
}
return;
}
if (is_inner_) {
is_match_ = true;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
}
}
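// The inner AddN may be merged only if this node has exactly one user; otherwise
// other users would still need the original, unmerged result.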
bool is_unique(const AnfNodePtr &node) {
auto mng = optimizer_->resource()->manager();
auto &node_users = mng->node_users();
if (node_users.find(node) == node_users.end()) {
return false;
}
size_t n_use = node_users[node].size();
return n_use == 1;
}
void Reset() {
Xs_.clear();
Ys_.clear();
args_.clear();
is_inner_ = false;
is_outer_ = false;
is_match_ = false;
}
private:
OptimizerPtr optimizer_{nullptr};
std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
bool is_inner_{false}, is_outer_{false}, is_match_{false};
};

// {PrimAddN, {kPrimMakeTuple, Xs}}
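// Drops zeros_like inputs, e.g. AddN((x, zeros_like(y), z)) -> AddN((x, z)).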
class AddNZeroFilter : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
if (filtered_Xs_.empty() || node->func_graph() == nullptr) {
return nullptr;
}
// if only two nodes are left in filtered_Xs_, i.e. {make_tuple, x}, return x directly.
if (filtered_Xs_.size() == 2) {
return filtered_Xs_[1];
}
// if only the make_tuple node is left, every input was a zeros_like; return one of them.
if (filtered_Xs_.size() == 1 && !Xs_.empty()) {
return Xs_[0];
}
if (!has_zero_like_) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
auto addn = NewValueNode(GetValueNode(cnode->input(0)));
auto fg = node->func_graph();
auto make_tuple = fg->NewCNode(filtered_Xs_);
return fg->NewCNode({addn, make_tuple});
}
void Visit(const CNodePtr &cnode) override {
if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
return;
}
auto &inputs = cnode->inputs();
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
// {kPrimMakeTuple, X1, X2, ...}
filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &x : Xs_) {
if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
filtered_Xs_.push_back(x);
} else {
has_zero_like_ = true;
}
}
}
void Reset() {
Xs_.clear();
filtered_Xs_.clear();
has_zero_like_ = false;
}
private:
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false};
};

// {PrimAddN, {kPrimMakeTuple, Xs}}
// Akg doesn't support AddN(ValueNode, Tensor, ...), so it is converted to TensorAdd.
// case0: AddN(inputs) (inputs size < 2) -> error
// case1: AddN(inputs) (all inputs are ValueNodes) -> error
// case2: AddN(inputs) (inputs size = 2) -> TensorAdd(Tensor, Tensor)
// case3: AddN(ValueNode, Tensor, Tensor, ...) (has one ValueNode input)
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
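// e.g. case3: AddN((v, t1, t2)) with constant tensor v
// -> TensorAdd(v, AddN((t1, t2)))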
class AddNEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
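// Work on a clone so the original func graph is only swapped in if something changed.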
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
need_update_ = false;
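// case3 rewrites an AddN into a smaller AddN, so iterate Process to a fixed point.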
bool changed;
do {
changed = Process(new_fg);
} while (changed);
if (!need_update_) {
return nullptr;
} else {
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
}
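// Walks func_graph once in topological order, rewriting each AddN per the cases
// above; returns true if any node was replaced.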
bool Process(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto nodes = TopoSort(func_graph->output());
bool changed = false;
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &tuple_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(tuple_input);
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
auto &tuple_inputs = tuple_input_cnode->inputs();
if (tuple_inputs.size() < 3) {
// case0: inputs size < 2, error
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
}
int valuenode_num =
std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) {
if (IsValueNode<tensor::Tensor>(node)) {
return accumulator + 1;
} else {
return accumulator;
}
});
if (IntToSize(valuenode_num) == tuple_inputs.size() - 1) {
// case1: all inputs are ValueNodes, error (tuple_inputs[0] is the MakeTuple primitive)
MS_EXCEPTION(ArgumentError) << "All inputs of AddN are ValueNodes. " << cnode->DebugString(2);
}
if (tuple_inputs.size() == 3) {
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
MS_LOG(DEBUG) << "Replace two-input AddN with TensorAdd. " << cnode->DebugString(2);
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
tuple_inputs[2]};
mng->Replace(node, func_graph->NewCNode(new_xs));
changed = true;
continue;
}
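// Look for a constant tensor input that can be hoisted out of the AddN (case3).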
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
if (first_valuenode == tuple_inputs.end()) {
// no ValueNode input found.
continue;
} else {
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
std::vector<AnfNodePtr> make_tuple_new_xs{
NewValueNode(prim::kPrimMakeTuple),
};
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
if (node != *first_valuenode) {
make_tuple_new_xs.push_back(node);
}
});
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
auto new_addn = func_graph->NewCNode(
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
auto new_add =
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
(void)mng->Replace(node, new_add);
changed = true;
continue;
}
}
need_update_ = need_update_ || changed;
return changed;
}
private:
bool need_update_{false};
};
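// These visitors are typically registered elsewhere (e.g. in irpass.cc) via the
// substitution helpers of this codebase; a sketch, assuming MakeSubstitution:
//   merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN);
//   addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN);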
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_