!7236 Add new pass:arithmetic_simplify and eliminate_empty_graph

Merge pull request !7236 from gengfei/1012_simplify_1.0
pull/7236/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f4421b4504

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
class ArithmeticSimplify : public Pass {
public:
ArithmeticSimplify() : Pass("arithmetic_simplify") {}
~ArithmeticSimplify() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
using ArithmeticSimplifyPtr = std::shared_ptr<ArithmeticSimplify>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_

@ -713,26 +713,6 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
std::vector<int> need_inline_;
};
// Eliminate the redundant MakeTuple-GetItem operations.
void EliminateTupleGetItem(const FuncGraphPtr &func_graph) {
auto callback = [](const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) return;
for (size_t i = 1; i < cnode->size(); ++i) {
auto getitem = cnode->input(i);
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
auto getitem_cnode = getitem->cast<CNodePtr>();
auto maketuple = getitem_cnode->input(kRealInputNodeIndexInTupleGetItem);
if (!AnfAlgo::CheckPrimitiveType(maketuple, prim::kPrimMakeTuple)) continue;
auto maketuple_cnode = maketuple->cast<CNodePtr>();
int getitem_idx =
GetValue<int>(getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value());
cnode->set_input(i, maketuple_cnode->input(getitem_idx + 1));
}
};
TraverseFuncGraph(func_graph, callback);
}
bool TrySplit(const CNodePtr &sub_root_cnode) {
MS_LOG(INFO) << "Split process node: " << sub_root_cnode->fullname_with_scope();
auto splitter = Splitter::MakeSplitter(sub_root_cnode, std::make_shared<CostModelSplitSchemer>());
@ -761,9 +741,6 @@ bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
changed = TrySplit(node) || changed;
}
}
if (changed) {
EliminateTupleGetItem(func_graph);
}
mng->RemoveRoots();
mng->KeepRoots({func_graph});
return changed;

@ -43,6 +43,7 @@
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
@ -116,7 +117,11 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::BindValueToGraph>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);

@ -152,6 +152,49 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
mutable AnfNodePtr captured_binop_node_{nullptr};
};
template <typename T>
class PUnaryOperation : public PBase<PUnaryOperation<T> > {
public:
PUnaryOperation(const PrimitivePtr &prim, const T &x) : prim_(prim), x_(x) {}
~PUnaryOperation() = default;
AnfNodePtr GetNode(const AnfNodePtr &node) const {
AnfNodePtrList list = {NewValueNode(prim_), x_.GetNode(node)};
return NewCNode(list, node->func_graph());
}
bool TryCapture_(const AnfNodePtr &node) const {
if (IsPrimitiveCNode(node, prim_)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
if (inputs.size() == 2 && x_.TryCapture(inputs[1])) {
captured_unaryop_node_ = node;
return true;
}
}
return false;
}
AnfNodePtr GetOriginalNode() const {
if (captured_unaryop_node_ == nullptr) {
MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
}
return captured_unaryop_node_;
}
void Reset() const {
x_.Reset();
captured_unaryop_node_ = nullptr;
}
using Internal = const PUnaryOperation<T> &;
private:
const PrimitivePtr prim_;
typename T::Internal x_;
mutable AnfNodePtr captured_unaryop_node_{nullptr};
};
///
/// Helper functions to apply a pattern function on all elements of a tuple
///
@ -681,10 +724,74 @@ class PConstant : public PBase<PConstant<T> > {
return new_vnode;
}
// Support function to multiply two constant tensors: partially support broadcasting shapes
template <typename TD>
TD CalcuConstant(const TD &data, const PrimitivePtr &calcu_type) {
TD tmp_data = data;
if (calcu_type == prim::kPrimReciprocal) {
if (data == 0) {
MS_EXCEPTION(ValueError);
} else {
tmp_data = 1 / data;
}
}
if (calcu_type == prim::kPrimNeg) {
tmp_data = -data;
}
return tmp_data;
}
// calculate const with different operations
AnfNodePtr ValueNodeWithOprations(const PrimitivePtr &calcu_type) {
AnfNodePtr node = this->GetNode(captured_node_);
if (!node->isa<ValueNode>()) {
MS_EXCEPTION(ValueError) << "CalcuValue is trying to use a not ValueNode.";
}
auto value = node->cast<ValueNodePtr>()->value();
if (value->isa<tensor::Tensor>()) {
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) ||
(tensor_type == TypeId::kNumberTypeFloat64)) {
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
return nullptr;
}
data2[i] = CalcuConstant(data2[i], calcu_type);
}
}
if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
return nullptr;
}
data2[i] = CalcuConstant(data2[i], calcu_type);
}
}
if (tensor_type == TypeId::kNumberTypeFloat64) {
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
return nullptr;
}
data2[i] = CalcuConstant(data2[i], calcu_type);
}
}
return node;
}
return nullptr;
}
enum BinOperator {
ADD = 0,
MULTIPLY,
};
// Support function to add/multiply two constant tensors: partially support broadcasting shapes
template <typename TM>
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size) const {
void CalcByOperator(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size, BinOperator bin_operator) const {
TM *data_1 = reinterpret_cast<TM *>(in_data_1);
TM *data_2 = reinterpret_cast<TM *>(in_data_2);
TM *data_out = new TM[out_data_size];
@ -700,27 +807,42 @@ class PConstant : public PBase<PConstant<T> > {
}
if (in_data_2_size == 1) {
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[0];
if (bin_operator == ADD) {
data_out[i] += data_2[0];
} else {
data_out[i] *= data_2[0];
}
}
} else {
if (in_data_2_size < out_data_size) {
MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size.";
}
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[i];
if (bin_operator == ADD) {
data_out[i] += data_2[i];
} else {
data_out[i] *= data_2[i];
}
}
}
*out_data = reinterpret_cast<void *>(data_out);
return;
}
AnfNodePtr AddByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
AnfNodePtr vnode_1 = this->GetNode(captured_node_);
AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
return CalcConstantTensors(vnode_1, vnode_2, node_3, ADD);
}
AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
AnfNodePtr vnode_1 = this->GetNode(captured_node_);
AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
return MulConstantTensors(vnode_1, vnode_2, node_3);
return CalcConstantTensors(vnode_1, vnode_2, node_3, MULTIPLY);
}
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const {
AnfNodePtr CalcConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3,
BinOperator bin_operator) const {
if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
(vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
return nullptr;
@ -778,21 +900,21 @@ class PConstant : public PBase<PConstant<T> > {
void *data_out = nullptr;
if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) ||
(new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) {
Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
CalcByOperator<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<float *>(data_out);
} else {
if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) {
Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
CalcByOperator<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<double *>(data_out);
} else {
if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) ||
(new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) {
Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
CalcByOperator<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size, bin_operator);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<int *>(data_out);
} else {
@ -833,6 +955,8 @@ class PConstant : public PBase<PConstant<T> > {
// Arithmetic operations
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
BIN_OPERATION_PATTERN(operator/, prim::kPrimRealDiv, false);
BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
// Macros for match and replace
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
self.sqrt = P.Sqrt()
self.pow = P.Pow()
self.neg = P.Neg()
def construct(self, x, y):
add_res1 = self.add(x, 4)
add_res2 = self.add(add_res1, 5)
sub_res = self.sub(y, 3)
mul_res = self.mul(self.sqrt(add_res2), self.sqrt(sub_res))
div_res = self.div(mul_res, self.sqrt(mul_res))
pow_res = self.pow(y, 2)
neg_res = self.neg(self.neg(pow_res))
return self.add(div_res, neg_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.abs(input_y) + 3
add_res = input_x + 9
sub_res = input_y + (-3)
mul_res = np.sqrt(add_res * sub_res)
div_res = np.sqrt(mul_res)
pow_res = input_y * input_y
neg_res = pow_res
expect = div_res + neg_res
net = Net()
result = net(Tensor(input_x), Tensor(input_y))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
Loading…
Cancel
Save