From: @jiangzg001
Reviewed-by: @liangchenghui,@wuxuejian,@liangchenghui
Signed-off-by: @liangchenghui,@liangchenghui
pull/10352/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 3d81b1bfe5

@ -352,6 +352,13 @@ def _is_float_dtype(dtype):
return False
@constexpr
def _need_reduce_all(axis):
if axis == ():
return True
return False
class ClipByNorm(Cell):
r"""
Clips tensor values to a maximum :math:`L_2`-norm.
@ -424,7 +431,7 @@ class ClipByNorm(Cell):
intermediate = x * clip_norm
max_norm = self.max_op(l2norm, clip_norm)
if self.axis is None:
if _need_reduce_all(self.axis):
max_norm = self.expand_dims(max_norm, -1)
values_clip = self.cast(intermediate, mstype.float32) / max_norm
values_clip = self.reshape(values_clip, self.shape(x))

@ -15,7 +15,7 @@
"""embedding"""
import mindspore.common.dtype as mstype
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.common.tensor import Tensor, MetaTensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
@ -101,8 +101,11 @@ class Embedding(Cell):
if padding_idx is not None:
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
"padding_idx", self.cls_name)
self.init_tensor = self.init_tensor.to_tensor().asnumpy()
if isinstance(self.init_tensor, MetaTensor):
self.init_tensor = self.init_tensor.to_tensor()
self.init_tensor = self.init_tensor.asnumpy()
self.init_tensor[self.padding_idx] = 0
self.init_tensor = Tensor(self.init_tensor)
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
self.expand = P.ExpandDims()
self.reshape_flat = P.Reshape()

@ -77,7 +77,9 @@ def _get_square_sum(x):
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, x):
x_dtype = F.dtype(x)
x = x * clip_norm / global_norm
x = F.cast(x, x_dtype)
return x

@ -6655,7 +6655,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
args = {'init_h': h_dtype, 'bias_input': binput_dtype}
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
b_dtype = binput_dtype
elif bhidden_dtype is not None:
if bhidden_dtype is not None:
args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
b_dtype = bhidden_dtype

Loading…
Cancel
Save