@@ -48,6 +48,9 @@ grad_overflow = P.FloatStatus()
 def _tensor_grad_overflow(grad):
     return grad_overflow(grad)
 
+@_grad_overflow.register("RowTensor")
+def _tensor_grad_overflow_row_tensor(grad):
+    return grad_overflow(grad.values)
 
 class DynamicLossScaleUpdateCell(Cell):
     r"""