fix sparse loss scale

pull/6430/head
riemann_penn 4 years ago
parent c55bd78ce5
commit 9d8b14d942

@ -48,6 +48,9 @@ grad_overflow = P.FloatStatus()
def _tensor_grad_overflow(grad):
return grad_overflow(grad)
@_grad_overflow.register("RowTensor")
def _tensor_grad_overflow_row_tensor(grad):
return grad_overflow(grad.values)
class DynamicLossScaleUpdateCell(Cell):
r"""

Loading…
Cancel
Save