From 9d8b14d942415acd9a6e22f283377f7e810636a3 Mon Sep 17 00:00:00 2001 From: riemann_penn Date: Thu, 17 Sep 2020 20:06:46 +0800 Subject: [PATCH] fix sparse loss scale --- mindspore/nn/wrap/loss_scale.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index d10b7984fb..729119db9c 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -48,6 +48,9 @@ grad_overflow = P.FloatStatus() def _tensor_grad_overflow(grad): return grad_overflow(grad) +@_grad_overflow.register("RowTensor") +def _tensor_grad_overflow_row_tensor(grad): + return grad_overflow(grad.values) class DynamicLossScaleUpdateCell(Cell): r"""