From 7578f23632767abf6fe4c9f745b5967ffd8b6612 Mon Sep 17 00:00:00 2001
From: "wangnan39@huawei.com" <wangnan39@huawei.com>
Date: Mon, 14 Sep 2020 11:44:23 +0800
Subject: [PATCH] optim TrainOneStepCell

---
 mindspore/nn/wrap/cell_wrapper.py              | 12 +++++-----
 mindspore/nn/wrap/loss_scale.py                | 21 +++---------------
 .../nlp/bert/src/bert_for_pre_training.py      | 22 ++-----------------
 .../bert_thor/src/bert_for_pre_training.py     | 22 ++-----------------
 4 files changed, 12 insertions(+), 65 deletions(-)

diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index 6d25daf41e..c1bb9429ef 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -185,23 +185,21 @@ class TrainOneStepCell(Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
-        self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+        self.grad_reducer = F.identity
+        self.parallel_mode = _get_parallel_mode()
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
             mean = _get_gradients_mean()
             degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(*inputs, sens)
-        if self.reducer_flag:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
+        grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index dce621a765..b1333117e2 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -14,9 +14,8 @@
 # ============================================================================
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
 from ...common.parameter import Parameter
@@ -163,7 +162,7 @@ class FixedLossScaleUpdateCell(Cell):
         return overflow


-class TrainOneStepWithLossScaleCell(Cell):
+class TrainOneStepWithLossScaleCell(TrainOneStepCell):
     r"""
     Network training with loss scaling.
@@ -203,15 +202,8 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
         >>> output = train_network(inputs, label, scaling_sens)
     """
-    def __init__(self, network, optimizer, scale_sense):
-        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+    def __init__(self, network, optimizer, scale_sense):
+        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.hyper_map = C.HyperMap()
         if context.get_context("device_target") == "GPU":
             self.gpu_target = True
@@ -228,13 +220,6 @@
         self.less_equal = LessEqual()
         self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
-        self.parallel_mode = _get_parallel_mode()
-        self.grad_reducer = F.identity
-        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
-        if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
         self.loss_scaling_manager = None
diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
index eb19a5c88f..4459d9842c 100644
--- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
@@ -271,23 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
         sens (Number): The adjust parameter. Default: 1.0.
""" def __init__(self, network, optimizer, sens=1.0): - super(BertTrainOneStepCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.sens = sens - self.reducer_flag = False - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - + super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) self.cast = P.Cast() self.hyper_map = C.HyperMap() @@ -322,9 +306,7 @@ class BertTrainOneStepCell(nn.Cell): self.cast(F.tuple_to_array((self.sens,)), mstype.float32)) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if self.reducer_flag: - # apply grad reducer on grads - grads = self.grad_reducer(grads) + grads = self.grad_reducer(grads) succ = self.optimizer(grads) return F.depend(loss, succ) diff --git a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py index cf8e6214bd..632f8825bd 100644 --- a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py @@ -289,23 +289,7 @@ class BertTrainOneStepCell(nn.Cell): """ def __init__(self, network, optimizer, sens=1.0): - super(BertTrainOneStepCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.sens = sens - self.reducer_flag = False - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - + super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) self.cast = P.Cast() self.hyper_map = C.HyperMap() @@ -340,9 +324,7 @@ class BertTrainOneStepCell(nn.Cell): self.cast(F.tuple_to_array((self.sens,)), mstype.float32)) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if self.reducer_flag: - # apply grad reducer on grads - grads = self.grad_reducer(grads) + grads = self.grad_reducer(grads) succ = self.optimizer(grads) return F.depend(loss, succ)