diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index ccfdbba67b..c4eaa3b12a 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -349,9 +349,9 @@ class TrainOneStepCell(Cell): if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): self.reducer_flag = True if self.reducer_flag: - mean = _get_gradients_mean() - degree = _get_device_num() - self.grad_reducer = DistributedGradReducer(self.weights, mean, degree) + self.mean = _get_gradients_mean() + self.degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree) def construct(self, *inputs): weights = self.weights