diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 2449eea9b4..2f8b38e818 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -23,6 +23,7 @@ from mindspore.ops import functional as F from mindspore.ops.functional import identity from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register +from mindspore.common.api import ms_function from ..cell import Cell from .activation import get_activation from ..._checkparam import Validator as validator @@ -261,7 +262,9 @@ class ClipByNorm(Cell): self.expand_dims = P.ExpandDims() self.dtype = P.DType() + @ms_function def construct(self, x, clip_norm): + """add ms_function decorator for pynative mode""" mul_x = F.square(x) l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32) cond = self.greater_(l2sum, self.zero)