@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Loss scale cell for loss scale training."""
+import mindspore.context as context
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
@@ -34,6 +35,13 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))
 
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
+
 
 class DynamicLossScaleUpdateCell(Cell):
     r"""
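
Note, not part of the diff itself: the _grad_overflow graph added above registers P.FloatStatus for Tensor inputs so that C.HyperMap can later apply it to every gradient in a tuple; FloatStatus is assumed here to return a one-element tensor that is non-zero when its input contains inf or nan. A minimal sketch of that usage pattern, assuming a GPU device target:

    from mindspore.ops import composite as C
    from mindspore.ops import functional as F
    from mindspore.ops import operations as P

    hyper_map = C.HyperMap()
    addn = P.AddN()
    # Inside a cell's construct, with grads a tuple of gradient tensors,
    # map the registered overflow check over every gradient, then sum the
    # per-gradient flags into a single flag tensor:
    #     flag_sum = addn(hyper_map(F.partial(_grad_overflow), grads))
    #     # flag_sum > 0 means at least one gradient contained inf/nan
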
@@ -197,9 +205,15 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
-        self.alloc_status = NPUAllocFloatStatus()
-        self.get_status = NPUGetFloatStatus()
-        self.clear_status = NPUClearFloatStatus()
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+        else:
+            self.gpu_target = False
+            self.alloc_status = NPUAllocFloatStatus()
+            self.get_status = NPUGetFloatStatus()
+            self.clear_status = NPUClearFloatStatus()
         self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = LessEqual()
@@ -224,10 +238,12 @@ class TrainOneStepWithLossScaleCell(Cell):
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
-        # init overflow buffer
-        init = self.alloc_status()
-        # clear overflow buffer
-        self.clear_status(init)
+        init = False
+        if not self.gpu_target:
+            # init overflow buffer
+            init = self.alloc_status()
+            # clear overflow buffer
+            self.clear_status(init)
         if sens is None:
             scaling_sens = self.loss_scale
         else:
@@ -238,9 +254,13 @@ class TrainOneStepWithLossScaleCell(Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
-        # get the overflow buffer
-        self.get_status(init)
-        # sum overflow buffer elements, 0:not overflow , >0:overflow
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            self.get_status(init)
+            # sum overflow buffer elements, 0:not overflow , >0:overflow
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
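
Taken together, the GPU path replaces the NPU float-status buffer with a per-gradient check: FloatStatus is mapped over every gradient, the resulting flags are summed with AddN, and a non-zero sum means some gradient contained inf or nan, so the optimizer step should be skipped. A self-contained NumPy sketch of that logic, under the assumption that FloatStatus reports a non-zero value exactly when its input has inf/nan (names below are illustrative, not from the patch):

    import numpy as np

    def float_status(grad):
        # Assumed FloatStatus semantics: one-element array, non-zero iff inf/nan present.
        return np.array([0.0 if np.isfinite(grad).all() else 1.0], dtype=np.float32)

    grads = (np.array([0.1, -0.2], np.float32),
             np.array([np.inf, 1.0], np.float32))
    # hyper_map(F.partial(_grad_overflow), grads): apply the check to each gradient
    flags = [float_status(g) for g in grads]
    # addn(flags): elementwise sum of the per-gradient flags
    flag_sum = np.sum(flags)
    overflow = flag_sum > 0
    print(overflow)  # True: the second gradient contains inf, so the update would be skipped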