|
|
|
@ -22,6 +22,7 @@ from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
|
from mindspore.ops.functional import identity
|
|
|
|
|
from mindspore.ops.operations import _inner_ops as inner
|
|
|
|
|
from mindspore.ops.primitive import constexpr
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore._extends import cell_attr_register
|
|
|
|
|
from mindspore.common.api import ms_function
|
|
|
|
@ -236,6 +237,13 @@ class Dense(Cell):
|
|
|
|
|
return str_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _is_equal_one(x):
|
|
|
|
|
if x is None:
|
|
|
|
|
return False
|
|
|
|
|
return bool(x.asnumpy().mean() == 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClipByNorm(Cell):
|
|
|
|
|
r"""
|
|
|
|
|
Clips tensor values to a maximum :math:`L_2`-norm.
|
|
|
|
@ -290,7 +298,10 @@ class ClipByNorm(Cell):
|
|
|
|
|
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
|
|
|
|
|
l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
|
|
|
|
|
|
|
|
|
|
intermediate = x * clip_norm
|
|
|
|
|
if _is_equal_one(clip_norm):
|
|
|
|
|
intermediate = x
|
|
|
|
|
else:
|
|
|
|
|
intermediate = x * clip_norm
|
|
|
|
|
max_norm = self.max_op(l2norm, clip_norm)
|
|
|
|
|
values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
|
|
|
|
|
values_clip = self.reshape(values_clip, self.shape(x))
|
|
|
|
|