|
|
|
@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore._extends import cell_attr_register
|
|
|
|
|
from mindspore._checkparam import Rel, Validator
|
|
|
|
|
from mindspore.common.api import ms_function
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from ..cell import Cell
|
|
|
|
|
from .activation import get_activation
|
|
|
|
@ -413,9 +412,7 @@ class ClipByNorm(Cell):
|
|
|
|
|
self.expand_dims = P.ExpandDims()
|
|
|
|
|
self.dtype = P.DType()
|
|
|
|
|
|
|
|
|
|
@ms_function
|
|
|
|
|
def construct(self, x, clip_norm):
|
|
|
|
|
"""add ms_function decorator for pynative mode"""
|
|
|
|
|
mul_x = F.square(x)
|
|
|
|
|
l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
|
|
|
|
|
cond = self.greater_(l2sum, 0)
|
|
|
|
|