From ea6958c50a02c87350b26ec2080abc6650f1a045 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Sun, 12 Apr 2020 23:18:04 +0800 Subject: [PATCH] add pattern AdjustAllReduceMulAdd --- mindspore/ccsrc/operator/ops.cc | 1 + mindspore/ccsrc/operator/ops.h | 1 + mindspore/ccsrc/optimizer/irpass.cc | 2 +- .../optimizer/irpass/arithmetic_simplify.h | 78 +++++++++++++++++++ mindspore/ops/operations/array_ops.py | 2 +- mindspore/ops/operations/nn_ops.py | 2 +- tests/ut/cpp/optimizer/lib_test.cc | 19 +++++ .../gtest_input/optimizer/opt_test.py | 45 ++++++++++- 8 files changed, 145 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index f3053cac7d..8cf2d1290f 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -226,6 +226,7 @@ const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); +const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); // Debug ops const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 2dc7072972..548980bf2d 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -231,6 +231,7 @@ extern const PrimitivePtr kPrimMinimumGrad; extern const PrimitivePtr kPrimMaximumGrad; // Comm ops +extern const PrimitivePtr kPrimAllReduce; extern const PrimitivePtr kPrimMirror; extern const PrimitivePtr kPrimVirtualDiv; extern const PrimitivePtr kPrimVirtualDataset; diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 96d88f6e61..d64df33f99 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -48,7 +48,7 @@ namespace irpass { 
OptimizeIRPassLib::OptimizeIRPassLib() { arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, - prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); + prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index ab191aab20..0d48fc1463 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -228,6 +228,82 @@ class ConstantDuplicateMul : public AnfVisitor { CNodePtr cnode_; }; +// grad = AllReduce(grad) / worker_number +// grad = grad + weight * decy +// -> +// grad = grad + weight * decy +// grad = AllReduce(grad) / worker_number + +// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> +// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} +class AdjustAllReduceMulAdd : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + // {prim::kPrimAddN, Zs} + if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { + return nullptr; + } + auto addn = node->cast<CNodePtr>(); + if (addn->size() != 2) { + return nullptr; + } + + AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); + if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { + return nullptr; + } + + auto fg = node->func_graph(); + AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); + AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); + AnfNodePtr all_reduce = 
NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); + return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); + } + + void Visit(const AnfNodePtr &node) override { + if (level_ == 0) { + level_ = 1; + is_reduce_match_ = false; + // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} + AnfVisitor::Match(prim::kPrimMul)(node); + level_ = 0; + if (is_reduce_match_) { + y_ = tmp_; + } else { + z_ = node; + } + } + + if (level_ == 1) { + // {prim::kPrimAllReduce, X} + if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { + auto cnode = node->cast<CNodePtr>(); + if (cnode->size() > 1) { + x_ = cnode->input(1); + is_reduce_match_ = true; + } + } else { + tmp_ = node; + } + } + } + + void Reset() { + level_ = 0; + is_reduce_match_ = false; + x_ = nullptr; + y_ = nullptr; + z_ = nullptr; + tmp_ = nullptr; + } + + private: + int level_{0}; + bool is_reduce_match_{false}; + AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; +}; + class ArithmeticSimplify { public: ArithmeticSimplify() @@ -243,6 +319,7 @@ class ArithmeticSimplify { eliminaters_.emplace_back(identity_); eliminaters_.emplace_back(opt_update_zero_tensor_); eliminaters_.emplace_back(constant_duplicate_mul_); + eliminaters_.emplace_back(adjust_allreduce_mul_add_); } ~ArithmeticSimplify() = default; @@ -264,6 +341,7 @@ class ArithmeticSimplify { PrimEliminater identity_; OptUpdateZeroTensor opt_update_zero_tensor_; ConstantDuplicateMul constant_duplicate_mul_; + AdjustAllReduceMulAdd adjust_allreduce_mul_add_; std::vector<TransformFuncType> eliminaters_{}; }; } // namespace irpass diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index a7c3f50440..b348e9a700 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1235,7 +1235,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`. 
Examples: - >>> input_x = Tensor([1, 2, 3, 4], mindspore.float) + >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32) >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) >>> num_segments = 4 >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 62265162a9..3f24996b1b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1572,7 +1572,7 @@ class LayerNorm(Primitive): `Layer Normalization <https://arxiv.org/abs/1607.06450>`_. .. math:: - y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta + y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 2d4cf0e78e..8e348c698a 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) { ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); } + +TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { + FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell"); + FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr"); + FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl"); + FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr"); + FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1"); + FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r"); + FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); + FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); + auto patterns = 
std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_}); + ASSERT_TRUE(CheckOpt(beforell, after1, patterns)); + ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); + ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); + ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); + ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); + ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); +} + } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py index d494ad27d3..d74aa15952 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py @@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag): def test_constant_duplicate_mul(tag): fns = FnDict() - Mul = Primitive('Mul'); - Sqrt = Primitive('Sqrt'); + Mul = Primitive('Mul') + Sqrt = Primitive('Sqrt') x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) @@ -936,3 +936,44 @@ def test_constant_duplicate_mul(tag): return Mul(Sqrt(x), Mul(tensor1, tensor2)) return fns[tag] + + +def test_adjust_allreduce_mul_add(tag): + fns = FnDict() + Mul = Primitive('Mul') + AddN = Primitive('AddN') + AllReduce = Primitive('AllReduce') + + @fns + def beforell(x, y, z): + return AddN((z, Mul(y, AllReduce(x)))) + + @fns + def beforelr(x, y, z): + return AddN((z, Mul(AllReduce(x), y))) + + @fns + def beforerl(x, y, z): + return AddN((Mul(y, AllReduce(x)), z)) + + @fns + def beforerr(x, y, z): + return AddN((Mul(AllReduce(x), y), z)) + + @fns + def after1(x, y, z): + return Mul(AllReduce(AddN((z, x))), y) + + @fns + def before2r(x, y, z): + return AddN((Mul(AllReduce(x), y), Mul(z, z))) + + @fns + def before2l(x, y, z): + return AddN((Mul(z, z), Mul(AllReduce(x), y))) + + @fns + def after2(x, y, z): + return Mul(AllReduce(AddN((Mul(z, z), x))), y) + + return fns[tag]