From 672244e0ac33648e997ab16d1a7c26b120c15b55 Mon Sep 17 00:00:00 2001
From: liubuyu
Date: Wed, 22 Apr 2020 14:43:19 +0800
Subject: [PATCH] add keep_bn_fp32 parameter

---
 .../pre_activate/ascend/ir_fusion/mul_addn_fusion.cc |  6 +++---
 mindspore/nn/optim/optimizer.py                      |  2 +-
 mindspore/train/model.py                             | 12 +++++++++---
 .../gtest_input/pre_activate/mul_addn_fusion_test.py |  2 +-
 4 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc
index 83c58ab547..a5e4675c8f 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc
@@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
   auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
   inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));
-  inputs.push_back(addn->input(1));
+  inputs.push_back(addn->input(2));  // scalar input should be 3rd input
   inputs.push_back(mul->input(lossscale_input_index));
   auto fusion_node = graph->NewCNode(inputs);
@@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const {
   VarPtr Z = std::make_shared<Var>();
 
   VectorRef mul({prim::kPrimMul, X, Z});
-  VectorRef addn({prim::kPrimAddN, Y, mul});
+  VectorRef addn({prim::kPrimAddN, mul, Y});
   return addn;
 }
 
@@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
   if (addn == nullptr || addn->inputs().size() != kAddNInputNum) {
     return nullptr;
   }
-  auto mul_anf = addn->input(2);
+  auto mul_anf = addn->input(1);
   if (mul_anf == nullptr) {
     return nullptr;
   }
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 72593e8001..bab539461e 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -177,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((gradient, weight * weight_decay))
+        return op_add((weight * weight_decay, gradient))
     return gradient
 
 
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index f4d1a324d1..698105889a 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -62,6 +62,7 @@ class Model:
         loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
             scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument.
             e.g. Use `loss_scale_manager=None` to set the value.
+        keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.
 
     Examples:
         >>> class Net(nn.Cell):
@@ -96,7 +97,10 @@ class Model:
         self._optimizer = optimizer
         self._loss_scale_manager = None
         self._loss_scale_manager_set = False
+        self._keep_bn_fp32 = True
         self._check_kwargs(kwargs)
+        if 'keep_batchnorm_fp32' in kwargs:
+            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
         if 'loss_scale_manager' in kwargs:
             self._loss_scale_manager = kwargs['loss_scale_manager']
             self._loss_scale_manager_set = True
@@ -112,7 +116,7 @@ class Model:
 
     def _check_kwargs(self, kwargs):
         for arg in kwargs:
-            if arg not in ['loss_scale_manager']:
+            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                 raise ValueError(f"Unsupport arg '{arg}'")
 
     def _build_train_network(self):
@@ -124,12 +128,14 @@ class Model:
                                                    self._optimizer,
                                                    self._loss_fn,
                                                    level=self._amp_level,
-                                                   loss_scale_manager=self._loss_scale_manager)
+                                                   loss_scale_manager=self._loss_scale_manager,
+                                                   keep_batchnorm_fp32=self._keep_bn_fp32)
             else:
                 network = amp.build_train_network(network,
                                                    self._optimizer,
                                                    self._loss_fn,
-                                                   level=self._amp_level)
+                                                   level=self._amp_level,
+                                                   keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py
index e5b0a15387..8ce64109c6 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py
@@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag):
     @fns
     def before(a, b):
         res = mul(scalar, a)
-        res = addn((b, res))
+        res = addn((res, b))
         return res
 
     @fns
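
Note (not part of the patch): the sketch below illustrates how the new `keep_batchnorm_fp32` keyword argument introduced here would be passed to `Model`. It is a minimal, hedged example: `ToyNet`, the loss, the optimizer, and the `amp_level="O2"` choice are placeholders, and the import paths assume the MindSpore layout of this revision.

import mindspore.nn as nn
from mindspore.train.model import Model  # path as of this revision

# Toy network with a BatchNorm layer, purely for illustration.
class ToyNet(nn.Cell):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.dense = nn.Dense(16, 10)
        self.bn = nn.BatchNorm1d(10)

    def construct(self, x):
        return self.bn(self.dense(x))

net = ToyNet()
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# keep_batchnorm_fp32 is the keyword argument added by this patch; Model
# forwards it to amp.build_train_network, so BatchNorm layers stay in float32
# (default True) even when the rest of the network runs at a reduced-precision
# amp level such as "O2".
model = Model(net, loss_fn=loss, optimizer=opt,
              amp_level="O2", keep_batchnorm_fp32=True)

Leaving the argument at its default of True preserves the previous behaviour of amp.build_train_network; passing False lets BatchNorm follow the amp level instead.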