@@ -17,10 +17,11 @@ from types import FunctionType, MethodType
 from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_auto_parallel_context
 from mindspore._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
 from ...common.tensor import Tensor
 from ...ops import composite as C
 from ...ops import functional as F
 from ...ops import operations as P
@@ -62,6 +63,19 @@ def _tensors_cast_datatype(datatype, param):
     """
     return F.cast(param, datatype)
 
 
+_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
+
+
+@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
+def _cumulative_grad(accumulation_step, cumulative_grad, grad):
+    """Apply gradient accumulation to cumulative grad."""
+    return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
+
+
+_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
+
+
+@_gradient_clear_op.register("Tensor")
+def _clear_grad(cumulative_grad):
+    zero_grad = P.ZerosLike()(cumulative_grad)
+    return F.assign(cumulative_grad, zero_grad)
+
+
 class WithLossCell(Cell):
     r"""
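Not part of the patch: a minimal, hedged sketch of the MultitypeFuncGraph + HyperMap pattern that `_cumulative_grad` and `_clear_grad` rely on. The names `_acc_op`, `_acc`, `Accumulate`, and `buf0` are made up for illustration; the API calls mirror the ones used in the patch itself.

import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

_acc_op = C.MultitypeFuncGraph("acc_op")


@_acc_op.register("Int64", "Tensor", "Tensor")
def _acc(step, buf, grad):
    # Same shape as _cumulative_grad above: buf += grad / step.
    return P.AssignAdd()(buf, grad / step)


class Accumulate(nn.Cell):
    def __init__(self, step):
        super(Accumulate, self).__init__()
        self.step = step
        self.hyper_map = C.HyperMap()
        self.bufs = ParameterTuple([Parameter(Tensor([0.0], mstype.float32), name="buf0")])

    def construct(self, grads):
        # HyperMap applies the registered overload element-wise across the tuples.
        return self.hyper_map(F.partial(_acc_op, self.step), self.bufs, grads)


acc = Accumulate(4)
acc((Tensor([4.0], mstype.float32),))  # buf0 becomes 1.0 (= 4.0 / 4)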
@@ -347,14 +361,27 @@ class TrainOneStepCell(Cell):
             self.mean = _get_gradients_mean()
             self.degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+        self.use_grad_accumulation = False
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
+            self.use_grad_accumulation = True
+        if self.use_grad_accumulation:
+            self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
+            if self.max_accumulation_step <= 1:
+                self.max_accumulation_step = 1
+                self.use_grad_accumulation = False
+        if self.use_grad_accumulation:
+            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
 
     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         sens = F.depend(sens, loss)
         grads = self.grad(self.network, weights)(*inputs, sens)
         grads = self.grad_reducer(grads)
 
-        loss = F.depend(loss, self.optimizer(grads))
+        if self.use_grad_accumulation:
+            loss = self.grad_accumulation(loss, grads)
+        else:
+            loss = F.depend(loss, self.optimizer(grads))
         return loss
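Not part of the patch: a hedged usage sketch of the accumulation path added to TrainOneStepCell above. The toy network, data, and hyper-parameters are placeholders, and it assumes grad_accumulation_step can be written through set_auto_parallel_context (the patch only shows the matching get_auto_parallel_context read).

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

# Assumption: the same auto-parallel context key read by the patch can be set here.
context.set_auto_parallel_context(grad_accumulation_step=4)

net = nn.Dense(16, 10)  # toy network standing in for a real model
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), opt)
train_net.set_train()

data = Tensor(np.random.randn(32, 16), mstype.float32)
label = Tensor(np.random.randint(0, 10, (32,)), mstype.int32)
for _ in range(8):
    loss = train_net(data, label)  # weights are updated only on every 4th call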
@@ -557,3 +584,34 @@ class _BroadCastCell(Cell):
         params = self.broadcast(params)
         new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
         return new_params
+
+
+class GradientAccumulation(Cell):
+    """
+    After accumulating the gradients of multiple steps, call to optimize its update.
+
+    Args:
+        max_accumulation_step (int): Steps to accumulate gradients.
+        optimizer (Cell): Optimizer used.
+    """
+    def __init__(self, max_accumulation_step, optimizer):
+        super(GradientAccumulation, self).__init__()
+        self._max_accumulation_step = max_accumulation_step
+        self.optimizer = optimizer
+        self.weights = optimizer.parameters
+        self.hyper_map = C.HyperMap()
+        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
+        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
+
+    def construct(self, loss, grads):
+        loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
+                                             self._grad_accumulation, grads))
+        self._accumulation_step += 1
+
+        if self._accumulation_step >= self._max_accumulation_step:
+            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
+            self._accumulation_step = 0
+
+        if self._accumulation_step == 0:
+            loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
+
+        return loss
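Not part of the patch: a plain-Python sketch of the update schedule GradientAccumulation implements. Each call adds grad / N to the buffer; on every N-th call the averaged gradient is applied and the buffer is cleared.

N = 4                       # stands in for _max_accumulation_step
buffer, step = 0.0, 0       # stand in for _grad_accumulation and _accumulation_step
applied = []

for grad in [1.0, 2.0, 3.0, 6.0, 4.0, 4.0, 4.0, 4.0]:
    buffer += grad / N      # _cumulative_grad
    step += 1
    if step >= N:           # optimizer step with the accumulated (averaged) gradient
        applied.append(buffer)
        step = 0
    if step == 0:           # _clear_grad resets the buffer for the next group
        buffer = 0.0

print(applied)              # [3.0, 4.0]: the mean gradient of each group of N steps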