@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
 from .bert_model import BertModel
+from .utils import ClipByGlobalNorm
 
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
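The ClipByGlobalNorm cell is imported from .utils, but its body is not part of this diff. A minimal sketch of such a cell, assuming it simply wraps MindSpore's ops.clip_by_global_norm with the same 1.0 threshold used by the per-tensor path further down, might look like this (hypothetical, not the repo's actual implementation):

# Hypothetical sketch of the helper imported above; the real .utils.ClipByGlobalNorm
# lives elsewhere in the repository and may differ.
import mindspore.nn as nn
from mindspore import ops


class ClipByGlobalNorm(nn.Cell):
    """Clip a tuple of gradients so that their combined L2 norm is at most clip_norm."""

    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.clip_norm = clip_norm

    def construct(self, grads):
        # ops.clip_by_global_norm scales every tensor by clip_norm / max(global_norm, clip_norm),
        # so the relative magnitudes between parameters are preserved.
        return ops.clip_by_global_norm(grads, clip_norm=self.clip_norm)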
@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
+        self.enable_global_norm = enable_global_norm
         self.grad = C.GradOperation(get_by_list=True,
                                     sens_param=True)
         self.reducer_flag = False
@@ -419,7 +421,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
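For intuition, the existing per-tensor clip_grad path and the new global-norm path diverge when one parameter's gradient dominates. A small NumPy illustration (approximating clip_grad, whose registration is not shown in this diff, as per-tensor norm clipping at 1.0; values are made up):

import numpy as np

def clip_each_by_norm(grads, clip_value=1.0):
    # Per-tensor clipping: each gradient is rescaled independently
    # whenever its own L2 norm exceeds clip_value.
    out = []
    for g in grads:
        norm = np.linalg.norm(g)
        out.append(g if norm <= clip_value else g * (clip_value / norm))
    return out

def clip_by_global_norm(grads, clip_norm=1.0):
    # Global-norm clipping: one scale factor, computed from the norm of the
    # whole gradient vector, is applied to every tensor, so the relative
    # magnitudes between parameters are preserved.
    global_norm = np.sqrt(sum(np.sum(g * g) for g in grads))
    scale = min(1.0, clip_norm / global_norm)
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]
print(clip_each_by_norm(grads))    # first tensor rescaled to norm 1.0, second untouched
print(clip_by_global_norm(grads))  # both tensors scaled by the same factor (~0.199)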
@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                   batch_size * accumulation_steps. Default: 1.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
         super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
         self.one = Tensor(np.array([1]).astype(np.int32))
         self.zero = Tensor(np.array([0]).astype(np.int32))
         self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
@@ -580,7 +586,10 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(self.accu_grads)
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
-            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+            if self.enable_global_norm:
+                grads = ClipByGlobalNorm()(grads)
+            else:
+                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
             accu_overflow = self.overflow_reducer(accu_overflow)
             F.control_depend(grads, accu_overflow)
             overflow = self.less_equal(self.base, accu_overflow)
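Finally, a usage sketch of how the new flag would be threaded through when the training cell is built. The names netwithloss and optimizer and the loss-scale settings are illustrative placeholders, not part of this diff:

# Illustrative wiring only; the training script itself is outside this diff.
# netwithloss and optimizer are assumed to be constructed earlier.
from mindspore import nn

update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
                                            scale_factor=2,
                                            scale_window=1000)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss,
                                                 optimizer=optimizer,
                                                 scale_update_cell=update_cell,
                                                 enable_global_norm=True)
netwithgrads.set_train(True)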