From 6d195f340c1a6f1ed09f769a4dfccc6fdffca70d Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Fri, 19 Feb 2021 09:09:42 +0800 Subject: [PATCH] add SyncBatchNorm --- .../ascend/ascend_backend_optimization.cc | 2 + .../ascend/ir_fission/bn_grad_split.cc | 43 ++++ .../ascend/ir_fission/bn_grad_split.h | 8 + .../optimizer/ascend/ir_fission/bn_split.cc | 111 ++++++++++ .../optimizer/ascend/ir_fission/bn_split.h | 13 ++ mindspore/core/base/core_ops.h | 2 + mindspore/nn/layer/normalization.py | 207 +++++++++++++++--- mindspore/ops/_grad/grad_other_ops.py | 17 ++ mindspore/ops/operations/_grad_ops.py | 18 ++ mindspore/ops/operations/_inner_ops.py | 91 ++++++++ .../ascend/ir_fission/bn_grad_split_test.cc | 62 ++++++ .../ascend/ir_fission/bn_split_test.cc | 51 ++++- .../gtest_input/pre_activate/bn_grad_split.py | 39 ++++ .../gtest_input/pre_activate/bn_split.py | 34 +++ tests/ut/python/ops/test_ops.py | 10 + 15 files changed, 672 insertions(+), 36 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index dc49cd4b11..5a66901c64 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -280,6 +280,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr("ir_fusion_pm"); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc index 0cd9acec47..96233f122c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc +++ 
b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc @@ -18,6 +18,7 @@ #include #include +#include "backend/optimizer/ascend/ir_fission/bn_split.h" #include "utils/utils.h" #include "utils/ms_context.h" #include "backend/optimizer/common/helper.h" @@ -104,6 +105,36 @@ CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode MS_EXCEPTION_IF_NULL(make_tuple); return make_tuple; } + +CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + std::vector bn_update_grad_outputs; + CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); + if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { + MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(cnode); + } + + std::vector allreduce_mul_outputs; + for (size_t i = 0; i < bn_update_grad_outputs.size(); ++i) { + auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_update_grad_outputs[i], cnode); + allreduce_mul_outputs.emplace_back(allreduce_mul_output); + } + + std::vector bn_reduce_grad_outputs; + CreateOutputsOfReduceGrad(func_graph, cnode, allreduce_mul_outputs, &bn_reduce_grad_outputs); + if (bn_reduce_grad_outputs.size() != 1) { + MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(cnode); + } + + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0], + allreduce_mul_outputs[0], allreduce_mul_outputs[1]}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + return make_tuple; +} } // namespace const BaseRef BnGradSplit::DefinePattern() const { @@ -120,5 +151,17 @@ const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfN } return BNGradSplitForTBE(func_graph, cnode); } + +const BaseRef SyncBnGradSplit::DefinePattern() const { + 
VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimSyncBatchNormGrad, Xs}); +} + +const AnfNodePtr SyncBnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + return SyncBNGradSplitForTBE(func_graph, cnode); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h index 2e5b512dd1..aae516d5aa 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h @@ -28,6 +28,14 @@ class BnGradSplit : public PatternProcessPass { const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; }; + +class SyncBnGradSplit : public PatternProcessPass { + public: + explicit SyncBnGradSplit(bool multigraph = true) : PatternProcessPass("sync_bn_grad_split", multigraph) {} + ~SyncBnGradSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc index 2c4d8065c8..15d6101d15 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc @@ -17,6 +17,8 @@ #include #include +#include +#include #include "utils/utils.h" #include "utils/ms_context.h" @@ -28,6 +30,9 @@ namespace mindspore { namespace opt { namespace { +constexpr auto kReduceOpSum = "sum"; +constexpr auto kDeviceNum = 
"device_num"; + bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, std::vector *bn_training_reduce_outputs) { MS_EXCEPTION_IF_NULL(graph); @@ -117,8 +122,105 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr // Create BNTrainingUpdate node return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs); } + +AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) { + MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has less input than " << kBnInputTensorNum << " inputs."; + return nullptr; + } + // Create BNTrainingReduce node and get outputs of BNTrainingReduce + std::vector bn_training_reduce_outputs; + if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) { + MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split"; + return nullptr; + } + if (bn_training_reduce_outputs.size() != kBN1OutputNum) { + MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail" + << " trace: " << trace::DumpSourceLines(node); + } + + std::vector allreduce_mul_outputs; + for (size_t i = 0; i < bn_training_reduce_outputs.size(); ++i) { + auto allreduce_mul_output = CreateAllReduceAndMul(func_graph, bn_training_reduce_outputs[i], cnode); + allreduce_mul_outputs.emplace_back(allreduce_mul_output); + } + + // Create BNTrainingUpdate node + return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, allreduce_mul_outputs); +} } // namespace +AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(sync_bn_cnode); + if (!AnfAlgo::HasNodeAttr(kDeviceNum, sync_bn_cnode)) { + MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] 
does not have attr device_num."; + } + auto device_num = AnfAlgo::GetNodeAttr(sync_bn_cnode, kDeviceNum); + MS_LOG(INFO) << "device_num value: " << device_num; + float device_num_reciprocal = 1.0 / device_num; + + std::vector device_num_shape = {}; + auto device_num_reciprocal_tensor = std::make_shared(kNumberTypeFloat32, device_num_shape); + MS_EXCEPTION_IF_NULL(device_num_reciprocal_tensor); + auto data_ptr = device_num_reciprocal_tensor->data_c(); + MS_EXCEPTION_IF_NULL(data_ptr); + auto *val = reinterpret_cast(data_ptr); + *val = device_num_reciprocal; + + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto abstract = std::make_shared(kFloat32, device_num_shape); + auto device_num_reciprocal_value = kernel_graph->NewValueNode(abstract, device_num_reciprocal_tensor); + MS_EXCEPTION_IF_NULL(device_num_reciprocal_value); + kernel_graph->AddValueNodeToGraph(device_num_reciprocal_value); + return device_num_reciprocal_value; +} + +AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input, + const CNodePtr &sync_bn_cnode) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(allreduce_input); + MS_EXCEPTION_IF_NULL(sync_bn_cnode); + + // create AllReduce + std::vector allreduce_inputs = {NewValueNode(std::make_shared(kAllReduceOpName)), + allreduce_input}; + auto allreduce = graph->NewCNode(allreduce_inputs); + MS_EXCEPTION_IF_NULL(allreduce); + allreduce->set_abstract(allreduce_input->abstract()); + allreduce->set_scope(allreduce_input->scope()); + AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce); + AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce); + // use SyncBatchNorm's opid as AllReduce's fusion attr + auto sync_bn_opname = sync_bn_cnode->fullname_with_scope(); + auto opid_pos = sync_bn_opname.rfind("-op"); + if (opid_pos == std::string::npos) { + MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has no opid."; + } + int64_t opid = 
std::stol(sync_bn_opname.substr(opid_pos + 3)); + // user defined fusion should be greater than 1 + if (opid < 2) { + opid = opid - 2 + std::numeric_limits::max(); + } + AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(opid), allreduce); + + // create Mul + auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode); + std::vector mul_inputs = {NewValueNode(std::make_shared(kMulOpName)), allreduce, + device_num_reciprocal_vnode}; + auto mul = graph->NewCNode(mul_inputs); + MS_EXCEPTION_IF_NULL(mul); + mul->set_abstract(allreduce_input->abstract()); + mul->set_scope(allreduce_input->scope()); + return mul; +} + const BaseRef BnSplit::DefinePattern() const { VarPtr Xs = std::make_shared(); MS_EXCEPTION_IF_NULL(Xs); @@ -132,5 +234,14 @@ const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodeP } return SplitBatchNormForTBE(func_graph, node); } + +const BaseRef SyncBnSplit::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimSyncBatchNorm, Xs}); +} + +const AnfNodePtr SyncBnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + return SyncBNSplitForTBE(func_graph, node); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h index d14d1357a6..b0def9d224 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h @@ -28,6 +28,19 @@ class BnSplit : public PatternProcessPass { const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; }; + +class SyncBnSplit : public PatternProcessPass { + public: + explicit SyncBnSplit(bool multigraph = true) : PatternProcessPass("sync_bn_split", multigraph) {} + ~SyncBnSplit() override = default; + const 
BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode); + +AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input, + const CNodePtr &sync_bn_cnode); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 6c063d349f..1017f87880 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -222,6 +222,8 @@ inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared( inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared("FusedBatchNormGradEx"); inline const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); +inline const PrimitivePtr kPrimSyncBatchNorm = std::make_shared("SyncBatchNorm"); +inline const PrimitivePtr kPrimSyncBatchNormGrad = std::make_shared("SyncBatchNormGrad"); inline const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); inline const PrimitivePtr kPrimReluGradV2 = std::make_shared("ReluGradV2"); inline const PrimitivePtr kPrimRelu6Grad = std::make_shared("ReLU6Grad"); diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index f8107806d2..3acac0b5df 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -13,12 +13,17 @@ # limitations under the License. 
# ============================================================================ """normalization""" +import itertools + from mindspore.ops import operations as P from mindspore.ops import functional as F +from mindspore.ops.operations import _inner_ops as inner from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer +from mindspore.common._decorator import deprecated from mindspore.ops.primitive import constexpr import mindspore.context as context +from mindspore._checkparam import Rel from mindspore._checkparam import Validator as validator from mindspore._extends import cell_attr_register from mindspore.communication.management import get_group_size, get_rank @@ -26,8 +31,9 @@ from mindspore.communication import management from mindspore.ops import _selected_ops from ..cell import Cell -__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'InstanceNorm2d'] +__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d'] +SYNC_BN_GROUP_NAME = "" class _BatchNorm(Cell): """Batch Normalization base class.""" @@ -44,6 +50,7 @@ class _BatchNorm(Cell): moving_var_init='ones', use_batch_statistics=None, device_num_each_group=1, + process_groups=0, input_dims='2d', data_format='NCHW'): super(_BatchNorm, self).__init__() @@ -68,19 +75,47 @@ class _BatchNorm(Cell): gamma_init, num_features), name="gamma", requires_grad=affine) self.beta = Parameter(initializer( beta_init, num_features), name="beta", requires_grad=affine) - self.group = validator.check_positive_int(device_num_each_group) + self.group_device_num = validator.check_positive_int(device_num_each_group) + self.process_groups = process_groups self.is_global = False - if self.group != 1: + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + global SYNC_BN_GROUP_NAME + # for GlobalBatchNorm + if self.group_device_num != 1 and self.parallel_mode != 
context.ParallelMode.STAND_ALONE: self.rank_id = get_rank() self.rank_size = get_group_size() self.device_list = [i for i in range(0, self.rank_size)] - self.rank_list = self.list_group(self.device_list, self.group) + self.rank_list = self.list_group(self.device_list, self.group_device_num) self.rank_list_idx = len(self.rank_list) for i in range(self.rank_list_idx): - if self.rank_id in self.rank_list[i] and self.group != 1: + if self.rank_id in self.rank_list[i]: self.is_global = True - management.create_group('group' + str(i), self.rank_list[i]) - self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) + if SYNC_BN_GROUP_NAME == "": + SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i) + management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) + # for SyncBatchNorm + if self.process_groups != 0 and self.parallel_mode != context.ParallelMode.STAND_ALONE: + self.rank_id = get_rank() + self.rank_size = get_group_size() + if self.process_groups is not None: + validator.check_isinstance("process_groups", self.process_groups, list) + self._check_rank_ids(self.process_groups, self.rank_size) + for i in range(len(self.process_groups)): + validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list) + self.group_device_num = len(self.process_groups[i]) + if self.rank_id in self.process_groups[i] and self.group_device_num > 1: + self.is_global = True + if SYNC_BN_GROUP_NAME == "": + SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) + management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i]) + elif self.rank_size > 1: + self.is_global = True + self.group_device_num = self.rank_size + self.device_list = [i for i in range(0, self.rank_size)] + if SYNC_BN_GROUP_NAME == "": + SYNC_BN_GROUP_NAME = "sync_bn_group0" + management.create_group(SYNC_BN_GROUP_NAME, self.device_list) + self.shape = P.Shape() self.reduce_mean = P.ReduceMean(keep_dims=True) self.square = P.Square() @@ -109,9 +144,12 @@ class 
_BatchNorm(Cell): self.bn_train = P.FusedBatchNorm(mode=1, epsilon=self.eps, momentum=self.momentum) + if self.is_global: + self.bn_train = inner.SyncBatchNorm(epsilon=self.eps, + momentum=self.momentum, + group=SYNC_BN_GROUP_NAME, + device_num=self.group_device_num) self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) - self.enable_global_sync = self.is_global and (self.is_ge_backend or\ - (self.is_graph_mode and self._target == "Ascend")) data_parallel_strategy = ((1,), (1,)) data_parallel_strategy_one = ((1,), ()) @@ -135,26 +173,13 @@ class _BatchNorm(Cell): group_list = [list(i) for i in world_rank_list] return group_list - def _global_sync(self, x, axes, re_shape): - """calculate global batch normalization output""" - x_mean = self.reduce_mean(x, axes) - x_mean_square = self.reduce_mean(self.square(x), axes) - global_batch_mean = self.all_reduce(x_mean) / self.group - global_batch_mean_square = self.all_reduce(x_mean_square) / self.group - global_mean = global_batch_mean - global_var = global_batch_mean_square - self.square(global_mean) - var_sqrt = self.sqrt(global_var + self.eps) - mean_first = (x - global_mean) / var_sqrt - y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape) - - mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean) - tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub))) - mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var) - tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2))) - y = F.depend(y, self.assign_sub_mean(self.moving_mean, self.reshape(tmp_mean, self.shape(self.moving_mean)))) - y = F.depend(y, self.assign_sub_var(self.moving_variance, - self.reshape(tmp_variance, self.shape(self.moving_variance)))) - return y + def _check_rank_ids(self, process_groups, rank_size): + seen = set() + for rid in itertools.chain(*process_groups): + 
validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups") + if rid in seen: + raise ValueError("rank id in process_groups should not be duplicated.") + seen.add(rid) def construct(self, x): _shape_check_bn(self.shape(x), self.input_dims) @@ -164,10 +189,6 @@ class _BatchNorm(Cell): flag = self.use_batch_statistics if flag: - if self.enable_global_sync: - axes, re_shape = _shape_infer(F.shape(x), self.num_features) - return self._global_sync(x, axes, re_shape) - return self.bn_train(x, self.gamma, self.beta, @@ -597,6 +618,7 @@ class GlobalBatchNorm(_BatchNorm): [ 20.9999895 241.9988 ]]]] """ + @deprecated("1.2", "SyncBatchNorm", True) def __init__(self, num_features, eps=1e-5, @@ -619,8 +641,8 @@ class GlobalBatchNorm(_BatchNorm): use_batch_statistics, device_num_each_group, input_dims='both') - self.group = validator.check_positive_int(device_num_each_group) - if self.group <= 1: + self.group_device_num = validator.check_positive_int(device_num_each_group) + if self.group_device_num <= 1: raise ValueError("the number of group must be greater than 1.") def _check_data_dim(self, x): @@ -628,6 +650,121 @@ class GlobalBatchNorm(_BatchNorm): pass +class SyncBatchNorm(_BatchNorm): + r""" + Sync Batch normalization layer over a N-dimension input. + + Sync Batch Normalization is cross device synchronized batch normalization. The implementation of Batch + Normalization only normalizes the data within each device. Sync Batch normalization will normalize the input + within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by + Reducing Internal Covariate Shift `_. It rescales and recenters the + feature using a mini-batch of data and the learned parameters which can be described in the following formula. + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + Note: + Currently, SyncBatchNorm only supports 2D and 4D inputs. 
+ + Args: + num_features (int): `C` from an expected input of size (N, C, H, W). + eps (float): A value added to the denominator for numerical stability. Default: 1e-5. + momentum (float): A floating hyperparameter of the momentum for the + running_mean and running_var computation. Default: 0.9. + affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. + gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'ones'. + beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'zeros'. + moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'zeros'. + moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'ones'. + use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, + use the mean value and variance value of specified value. If None, training process will use the mean and + variance of current batch data and track the running mean and variance, eval process will use the running + mean and variance. Default: None. + process_groups (list): A list to divide devices into different sync groups, containing N sublists. + Each sublist contains int numbers identifying rank ids which need to be synchronized in the same + group. 
All int values must be in [0, rank_size) and different from each other. Default: None, indicating + synchronization across all devices. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Raises: + TypeError: If `num_features` is not an int. + TypeError: If `eps` is not a float. + TypeError: If `process_groups` is not a list. + ValueError: If `num_features` is less than 1. + ValueError: If `momentum` is not in range [0, 1]. + ValueError: If `device_num_each_group` is less than 2. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> # This example should be run with multiple processes. + >>> # Please refer to the tutorial > Distributed Training on mindspore.cn. + >>> import numpy as np + >>> from mindspore.communication import init + >>> from mindspore import context + >>> from mindspore.context import ParallelMode + >>> from mindspore import nn, Tensor + >>> from mindspore.common import dtype as mstype + >>> + >>> context.set_context(mode=context.GRAPH_MODE) + >>> init() + >>> context.reset_auto_parallel_context() + >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) + >>> np.random.seed(0) + >>> sync_bn_op = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]]) + >>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mstype.float32) + >>> output = sync_bn_op(input) + >>> print(output) + [[[[171.99915 46.999763] + [116.99941 191.99904 ]] + [[ 66.999664 250.99875 ] + [194.99902 102.99948 ]] + [[ 8.999955 210.99895 ] + [ 20.9999895 241.9988 ]]]] + """ + + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=None, + process_groups=None): + super(SyncBatchNorm, self).__init__(num_features, + eps, + momentum, + affine, + 
gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics, + process_groups=process_groups, + input_dims='both') + + def _check_data_dim(self, x): + if x.dim == 0: + pass + + class LayerNorm(Cell): r""" Applies Layer Normalization over a mini-batch of inputs. diff --git a/mindspore/ops/_grad/grad_other_ops.py b/mindspore/ops/_grad/grad_other_ops.py index eceacd2133..e1b4453a23 100644 --- a/mindspore/ops/_grad/grad_other_ops.py +++ b/mindspore/ops/_grad/grad_other_ops.py @@ -17,6 +17,8 @@ from .. import operations as P from .. import composite as C +from ..operations import _grad_ops as G +from ..operations import _inner_ops as inner from ..composite.multitype_ops.zeros_like_impl import zeros_like from .grad_base import bprop_getters @@ -64,5 +66,20 @@ def bprop_pqc(self): dx = t(dx, (1, 0)) dy = C.tensor_dot(dout[0], out[2], ((0, 1), (0, 1))) return dx, dy + return bprop + + +@bprop_getters.register(inner.SyncBatchNorm) +def get_bprop_sync_batch_norm(self): + """Grad definition for `SyncBatchNorm` operation.""" + input_grad = G.SyncBatchNormGrad(self.epsilon, self.group, self.device_num) + def bprop(x, scale, b, mean, variance, out, dout): + saved_mean = out[3] + saved_variance = out[4] + out = input_grad(dout[0], x, scale, saved_mean, saved_variance) + dx = out[0] + dscale = out[1] + dbias = out[2] + return dx, dscale, dbias, zeros_like(mean), zeros_like(variance) return bprop diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index e25af8e2bb..71638a73f8 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -204,6 +204,24 @@ class BatchNormGrad(PrimitiveWithInfer): return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) +class SyncBatchNormGrad(PrimitiveWithInfer): + """Performs grad of SyncBatchNorm operation.""" + + @prim_attr_register + def __init__(self, epsilon=1e-5, group="group0", device_num=2): + 
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) + if not isinstance(group, str): + raise TypeError("The group attr of SyncBatchNormGrad should be str.") + validator.check_int(device_num, 2, Rel.GE, "device_num", self.name) + + def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape): + validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) + return (x_shape, scale_shape, scale_shape) + + def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape): + return (x_type, scale_type, scale_type) + + class BiasAddGrad(PrimitiveWithInfer): """Computes gradients of BiasAdd.""" diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 1b3ce2a731..a73c821ee0 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -630,6 +630,7 @@ class GpuConvertToDynamicShape(PrimitiveWithCheck): def check_dtype(self, input_dtype): validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name) + class ErrorOnDynamicShapeInput(PrimitiveWithInfer): """ This op is used for dynamic shape testing. The only purpose of this operator is @@ -724,3 +725,93 @@ class SequenceMask(PrimitiveWithCheck): def check_dtype(self, lengths_dtype, maxlen_dtype): validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name) validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name) + + +class SyncBatchNorm(PrimitiveWithInfer): + r""" + Sync Batch Normalization for input data and updated parameters. + + Sync Batch Normalization is cross device synchronized batch normalization. Batch Normalization is + widely used in convolutional neural networks. 
This operation applies Batch Normalization over input + to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating + Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. + It rescales and recenters the features using a mini-batch of data and the learned parameters which + can be described in the following formula, + + .. math:: + y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta + + where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. + + Args: + epsilon (float): A small value added for numerical stability. Default: 1e-5. + momentum (float): The hyper parameter to compute moving average for running_mean and running_var + (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). + Momentum value must be in [0, 1]. Default: 0.1. + group (str): The communication group to work on. Default: "sync_bn_group0". + device_num (int): The number of devices in each group. Default: 2. + + Inputs: + - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. + - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. + - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `scale`. + - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. + - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `mean`. + + Outputs: + Tuple of 5 Tensors, the normalized inputs and the updated parameters. + + - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`. + - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`. + - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`. + - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`. + - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`. 
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> # This example should be run with multiple processes. + >>> # Please refer to nn.SyncBatchNorm for direct use. + >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32) + >>> scale = Tensor(np.ones([2]), mindspore.float32) + >>> bias = Tensor(np.ones([2]), mindspore.float32) + >>> mean = Tensor(np.ones([2]), mindspore.float32) + >>> variance = Tensor(np.ones([2]), mindspore.float32) + >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm() + >>> output = sync_batch_norm(input_x, scale, bias, mean, variance) + >>> print(output) + (Tensor(shape=[2, 2], dtype=Float32, value= + [[ 1.00000000e+00, 1.00000000e+00], + [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value= + [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value= + [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value= + [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value= + [ 1.00000000e+00, 1.00000000e+00])) + """ + + @prim_attr_register + def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2): + validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) + validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) + validator.check_isinstance("group", group, str) + validator.check_int(device_num, 2, Rel.GE, "device_num", self.name) + self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], + outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) + + def infer_shape(self, input_x, scale, bias, mean, variance): + validator.check_equal_int(len(scale), 1, "scale rank", self.name) + validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) + validator.check("scale shape[0]", scale[0], "input_x channel", input_x[1], Rel.EQ, self.name) + validator.check_equal_int(len(mean), 1, "mean rank", self.name) + validator.check("mean shape", 
mean, "variance shape", variance, Rel.EQ, self.name) + validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) + return (input_x, scale, scale, scale, scale) + + def infer_dtype(self, input_x, scale, bias, mean, variance): + validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) + args = {"scale": scale, "bias": bias} + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + args_moving = {"mean": mean, "variance": variance} + validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name) + return (input_x, scale, bias, input_x, input_x) diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc index 455613bc61..aa87bb560f 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc @@ -100,5 +100,67 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_grad_split", "after2"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWBnGradSplit, test_sync_bn_grad_split_tbe) { + get_py_fun_.SetDoResolve(true); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "before"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + std::vector shp_b{64}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + auto b_abstract = std::make_shared(kFloat32, shp_b); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract}; + auto kernel_graph = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kernel_graph, nullptr); + + // get SyncBNGrad + CNodePtr ret = kernel_graph->get_return(); + EXPECT_NE(ret, nullptr); + EXPECT_NE(ret->input(1), nullptr); + EXPECT_TRUE(ret->input(1)->isa()); + auto make_tuple1 = 
ret->input(1)->cast(); + EXPECT_NE(make_tuple1->input(1), nullptr); + EXPECT_TRUE(make_tuple1->input(1)->isa()); + auto make_tuple2 = make_tuple1->input(1)->cast(); + EXPECT_NE(make_tuple2->input(1), nullptr); + EXPECT_TRUE(make_tuple2->input(1)->isa()); + auto tuple_getitem = make_tuple2->input(1)->cast(); + EXPECT_NE(tuple_getitem->input(1), nullptr); + EXPECT_TRUE(tuple_getitem->input(1)->isa()); + auto bn_grad = tuple_getitem->input(1)->cast(); + + // get param1 + EXPECT_NE(bn_grad->input(1), nullptr); + auto param1 = bn_grad->input(1); + + // set kernel for param1 + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder2; + builder2.SetOutputsFormat({kOpFormat_NC1HWC0}); + builder2.SetOutputsDeviceType({kNumberTypeFloat32}); + AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), param1.get()); + + // set kernel for SyncBNGrad + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1; + builder1.SetInputsFormat( + {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); + builder1.SetOutputsFormat( + {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); + builder1.SetInputsDeviceType( + {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); + builder1.SetOutputsDeviceType( + {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); + builder1.SetKernelType(TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), bn_grad.get()); + // do sync_bn_grad_split pass + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kernel_graph); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git 
a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc index 581b53761f..939d752dbf 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc @@ -86,7 +86,7 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) { builder.SetKernelType(KernelType::TBE_KERNEL); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get()); - // do bn_grad_split_pass + // do bn_split_pass auto optimizer = std::make_shared(); auto pm = std::make_shared(); auto pass = std::make_shared(); @@ -97,5 +97,54 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_split_tbe", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWBnSplit, test_sync_bn_split_tbe) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "before"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{1, 64, 112, 112}; + std::vector shp_b{64}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + auto b_abstract = std::make_shared(kFloat32, shp_b); + AbstractBasePtrList args_spec_list{x_abstract, b_abstract, b_abstract, b_abstract, b_abstract}; + auto kernel_graph = GetKernelGraph(g, args_spec_list); + + // get kernel + auto ret = kernel_graph->get_return(); + EXPECT_NE(ret, nullptr); + EXPECT_TRUE(ret->inputs().size() == 2); + auto make_tuple = ret->input(1)->cast(); + EXPECT_NE(make_tuple, nullptr); + EXPECT_TRUE(make_tuple->inputs().size() == 2); + auto item0 = make_tuple->input(1)->cast(); + EXPECT_NE(item0, nullptr); + EXPECT_TRUE(item0->inputs().size() == 3); + auto bn = item0->input(1); + EXPECT_NE(bn, nullptr); + EXPECT_TRUE(bn->isa()); + + // set kernel for SyncBN + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetInputsFormat( + {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); + builder.SetOutputsFormat( + 
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); + builder.SetInputsDeviceType( + {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); + builder.SetOutputsDeviceType( + {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); + builder.SetKernelType(KernelType::TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get()); + + // do sync_bn_split_pass + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + auto new_graph = optimizer->Optimize(kernel_graph); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_grad_split.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_grad_split.py index d48c34895a..eb9a69e4b2 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_grad_split.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_grad_split.py @@ -16,15 +16,21 @@ from mindspore.ops import Primitive from mindspore.ops.operations import _grad_ops as G from mindspore.ops import _constants as Constants +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype make_tuple = Primitive('make_tuple') tuple_getitem = Primitive(Constants.kTupleGetItem) bn_grad = G.BatchNormGrad(is_training=True) +sync_bn_grad = G.SyncBatchNormGrad() bn_grad1 = Primitive('BNGrad1') bn_grad2 = Primitive('BNGrad2') bn_grad3 = Primitive('BNGrad3') bn_training_update_grad = Primitive('BNTrainingUpdateGrad') bn_training_reduce_grad = Primitive('BNTrainingReduceGrad') +allreduce = Primitive('AllReduce') +mul = Primitive('Mul') +mul_value = Tensor(0.5, mstype.float32) class FnDict: @@ 
-85,3 +91,36 @@ def test_bn_grad_split(tag): return make_tuple(output) return fns[tag] + + +def test_sync_bn_grad_split(tag): + """ test_sync_bn_grad_split """ + fns = FnDict() + + @fns + def before(i0, i1, i2, i3, i4): + bn_grad_output = sync_bn_grad(i0, i1, i2, i3, i4) + item0 = tuple_getitem(bn_grad_output, 0) + item1 = tuple_getitem(bn_grad_output, 1) + item2 = tuple_getitem(bn_grad_output, 2) + output = make_tuple(item0, item1, item2) + return output + + @fns + def after(i0, i1, i2, i3, i4): + bn_update_grad_output = bn_training_update_grad(i0, i1, i3, i4) + update_output0 = tuple_getitem(bn_update_grad_output, 0) + update_output1 = tuple_getitem(bn_update_grad_output, 1) + allreduce_output0 = allreduce(update_output0) + allreduce_output1 = allreduce(update_output1) + update_item0 = mul(allreduce_output0, mul_value) + update_item1 = mul(allreduce_output1, mul_value) + bn_reduce_grad_output = bn_training_reduce_grad(i0, i1, update_item0, update_item1, i2, i3, i4) + output = make_tuple(bn_reduce_grad_output, update_item0, update_item1) + item0 = tuple_getitem(output, 0) + item1 = tuple_getitem(output, 1) + item2 = tuple_getitem(output, 2) + output = make_tuple(item0, item1, item2) + return make_tuple(output) + + return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_split.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_split.py index d27bb2890b..25e4dbe031 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_split.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/bn_split.py @@ -15,16 +15,23 @@ from mindspore.ops import Primitive from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner from mindspore.ops import _constants as Constants +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype make_tuple = Primitive('make_tuple') tuple_getitem = Primitive(Constants.kTupleGetItem) bn = P.BatchNorm(is_training=True) +sync_bn = 
inner.SyncBatchNorm() fused_bn1 = Primitive('FusedBN1') fused_bn2 = Primitive('FusedBN2') fused_bn3 = Primitive('FusedBN3') bn_training_reduce = Primitive('BNTrainingReduce') bn_training_update = Primitive('BNTrainingUpdate') +allreduce = Primitive('AllReduce') +mul = Primitive('Mul') +mul_value = Tensor(0.5, mstype.float32) class FnDict: @@ -89,3 +96,30 @@ def test_bn_split_tbe(tag): return make_tuple(output) return fns[tag] + + +def test_sync_bn_split_tbe(tag): + """ test_sync_split_bn_fusion """ + fns = FnDict() + + @fns + def before(x, scale, b, mean, variance): + bn_output = sync_bn(x, scale, b, mean, variance) + output = tuple_getitem(bn_output, 0) + return output + + @fns + def after(x, scale, b, mean, variance): + bn_training_reduce_output = bn_training_reduce(x) + bn_training_reduce_output0 = tuple_getitem(bn_training_reduce_output, 0) + bn_training_reduce_output1 = tuple_getitem(bn_training_reduce_output, 1) + allreduce_output0 = allreduce(bn_training_reduce_output0) + allreduce_output1 = allreduce(bn_training_reduce_output1) + bn_training_update_input1 = mul(allreduce_output0, mul_value) + bn_training_update_input2 = mul(allreduce_output1, mul_value) + bn_training_update_output = bn_training_update(x, bn_training_update_input1, bn_training_update_input2, + scale, b, mean, variance) + output = tuple_getitem(bn_training_update_output, 0) + return make_tuple(output) + + return fns[tag] diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index d525262046..2a86fb1249 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1755,6 +1755,16 @@ test_case_nn_ops = [ 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], 'skip': ['backward']}), + ('SyncBatchNorm', { + 'block': inner.SyncBatchNorm(), + 'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]], + 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], + 'skip': 
[]}), + ('SyncBatchNormGrad', { + 'block': G.SyncBatchNormGrad(), + 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], + 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], + 'skip': ['backward']}), ('TopK', { 'block': P.TopK(), 'desc_const': [5],