diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index 6d25daf41e..c1bb9429ef 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -185,23 +185,21 @@ class TrainOneStepCell(Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
-        self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
             mean = _get_gradients_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
 
     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(*inputs, sens)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
 
 
diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index dce621a765..b1333117e2 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -14,9 +14,8 @@
 # ============================================================================
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
 from ...common.parameter import Parameter
@@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
         return overflow
 
 
-class TrainOneStepWithLossScaleCell(Cell):
+class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     r"""
     Network training with loss scaling.
 
@@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
         >>> output = train_network(inputs, label, scaling_sens)
     """
-
     def __init__(self, network, optimizer, scale_sense):
-        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
         if context.get_context("device_target") == "GPU":
             self.gpu_target = True
@@ -228,13 +220,6 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.less_equal = LessEqual()
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
-        self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
-        if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
 
         self.loss_scaling_manager = None
diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
index eb19a5c88f..4459d9842c 100644
--- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
@@ -271,23 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
         sens (Number): The adjust parameter. Default: 1.0.
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
 
@@ -322,9 +306,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)
 
diff --git a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py
index cf8e6214bd..632f8825bd 100644
--- a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py
@@ -289,23 +289,7 @@ class BertTrainOneStepCell(nn.Cell):
     """
 
     def __init__(self, network, optimizer, sens=1.0):
-        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.sens = sens
-        self.reducer_flag = False
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-
+        super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
 
@@ -340,9 +324,7 @@ class BertTrainOneStepCell(nn.Cell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         succ = self.optimizer(grads)
         return F.depend(loss, succ)