From b02e871c1aa277217111a4fa2f3ce2ced40cf921 Mon Sep 17 00:00:00 2001 From: huanghui Date: Tue, 14 Apr 2020 20:30:44 +0800 Subject: [PATCH] [IRFusion] add derelu_fusion pass --- .../ascend/ir_fusion/derelu_fusion.cc | 105 ++++++++++++++++++ .../ascend/ir_fusion/derelu_fusion.h | 33 ++++++ mindspore/ccsrc/pre_activate/common/helper.h | 1 + mindspore/ccsrc/utils/utils.h | 2 + .../ascend/ir_fusion/derelu_fusion_test.cc | 54 +++++++++ .../gtest_input/pre_activate/derelu_fusion.py | 56 ++++++++++ 6 files changed, 251 insertions(+) create mode 100644 mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc create mode 100644 mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h create mode 100644 tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc create mode 100644 tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc new file mode 100644 index 0000000000..d5ea315de1 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" +#include +#include +#include "session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "pipeline/static_analysis/abstract_value.h" +#include "pre_activate/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kReluV2OutputNum = 2; + +CNodePtr GetRelu(const CNodePtr &relu_grad) { + MS_EXCEPTION_IF_NULL(relu_grad); + if (relu_grad->size() != kReluGradInputNum) { + MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); + } + auto relu_anf = relu_grad->input(2); + MS_EXCEPTION_IF_NULL(relu_anf); + return relu_anf->cast(); +} + +CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(relu); + if (relu->size() != kReluInputNum) { + MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); + } + + auto prim = std::make_shared(kReluV2OpName); + std::vector inputs = {NewValueNode(prim), relu->input(1)}; + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(relu->scope()); + + // ReluV2's 2nd output is a mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector + TypeId mask_dtype = kNumberTypeUInt8; + std::vector mask_shape; + auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype}; + auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); + return new_node; +} + +CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(relu_grad); + MS_EXCEPTION_IF_NULL(second_input); + + auto prim = std::make_shared(kReluGradV2OpName); + std::vector inputs = {NewValueNode(prim), relu_grad->input(1), second_input}; + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + 
new_node->set_scope(relu_grad->scope()); + new_node->set_abstract(relu_grad->abstract()); + return new_node; +} +} // namespace + +const BaseRef DereluFusion::DefinePattern() const { + VarPtr i0 = std::make_shared(); + VarPtr i1 = std::make_shared(); + VectorRef relu({prim::kPrimRelu, i1}); + VectorRef relu_grad({prim::kPrimReluGrad, i0, relu}); + return relu_grad; +} + +const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto relu_grad = node->cast(); + MS_EXCEPTION_IF_NULL(relu_grad); + auto relu = GetRelu(relu_grad); + MS_EXCEPTION_IF_NULL(relu); + + auto relu_v2 = CreateReluV2(graph, relu); + std::vector relu_v2_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); + + auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); + + auto manage = graph->manager(); + MS_EXCEPTION_IF_NULL(manage); + manage->Replace(relu, relu_v2_node_outputs[0]); + return relu_grad_v2; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h new file mode 100644 index 0000000000..e1811f4db4 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ + +#include +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class DereluFusion : public PatternProcessPass { + public: + explicit DereluFusion(bool multigraph = true) : PatternProcessPass("derelu_fusion", multigraph) {} + ~DereluFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index 4f30a935af..4cacd6fbcc 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -29,6 +29,7 @@ constexpr size_t kTransOpInputNum = 2; constexpr size_t kCastInputNum = 2; constexpr size_t kDependInputNum = 3; constexpr size_t kReluInputNum = 2; +constexpr size_t kReluGradInputNum = 3; constexpr size_t kAddInputNum = 3; constexpr size_t kAddNInputNum = 3; constexpr size_t kTupleGetitemInputNum = 3; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 08a98a3129..60d5830933 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -115,6 +115,8 @@ constexpr auto kBiasAddOpName = "BiasAdd"; constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; constexpr auto kSendOpName = "Send"; constexpr auto kRecvOpName = "Recv"; +constexpr auto kReluV2OpName = "ReluV2"; +constexpr auto kReluGradV2OpName = "ReluGradV2"; // attr key name constexpr auto kAttrInputNames = "input_names"; diff --git 
a/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc new file mode 100644 index 0000000000..ffa5a42b4d --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "pre_activate/common/optimizer.h" +#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +class TestHWOptimizeDereluFusion : public BackendCommon { + public: + TestHWOptimizeDereluFusion() : get_py_fun_("gtest_input.pre_activate.derelu_fusion", true) {} + ~TestHWOptimizeDereluFusion() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWOptimizeDereluFusion, test_fusion) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_derelu_fusion", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{1, 1, 1, 1}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 2; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr 
new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_derelu_fusion", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py new file mode 100644 index 0000000000..497975542b --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/derelu_fusion.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from mindspore.ops import operations as P +from mindspore.ops import Primitive + +relu = P.ReLU() +relu_grad = Primitive('ReluGrad') +relu_v2 = Primitive('ReluV2') +relu_grad_v2 = Primitive('ReluGradV2') +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + +def test_derelu_fusion(tag): + fns = FnDict() + + @fns + def before(i0, i1): + relu_res = relu(i1) + res = relu_grad(i0, relu_res) + other = relu(relu_res) + res = make_tuple(res, other) + return res + + @fns + def after(i0, i1): + relu_res = relu_v2(i1) + item0 = tuple_getitem(relu_res, 0) + item1 = tuple_getitem(relu_res, 1) + other = relu(item0) + res = relu_grad_v2(i0, item1) + res = make_tuple(res, other) + return make_tuple(res) + + return fns[tag]