From: @jiangzg001
Reviewed-by: @liangchenghui, @wuxuejian
Signed-off-by: @liangchenghui
pull/10352/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit 3d81b1bfe5

@@ -352,6 +352,13 @@ def _is_float_dtype(dtype):
     return False
 
 
+@constexpr
+def _need_reduce_all(axis):
+    if axis == ():
+        return True
+    return False
+
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
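Note on the helper above: a @constexpr function is evaluated while the graph is compiled, on Python-level values such as a cell attribute, so the empty-tuple comparison folds into a constant branch inside construct (it is used that way in the next hunk). Below is a minimal standalone sketch of the same pattern; the SumNet cell and its names are illustrative, not part of this PR.

import numpy as np
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops import constexpr
from mindspore.ops import operations as P

@constexpr
def _reduce_all(axis):
    # Evaluated at graph-compile time on the Python value of `axis`.
    return axis == ()

class SumNet(Cell):
    def __init__(self, axis=()):
        super(SumNet, self).__init__()
        self.axis = axis
        self.reduce_sum = P.ReduceSum()

    def construct(self, x):
        if _reduce_all(self.axis):   # folds into a constant branch
            return self.reduce_sum(x, ())
        return self.reduce_sum(x, self.axis)

net = SumNet()
print(net(Tensor(np.ones((2, 3), np.float32))))  # 6.0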
@@ -424,7 +431,7 @@ class ClipByNorm(Cell):
         intermediate = x * clip_norm
         max_norm = self.max_op(l2norm, clip_norm)
-        if self.axis is None:
+        if _need_reduce_all(self.axis):
             max_norm = self.expand_dims(max_norm, -1)
         values_clip = self.cast(intermediate, mstype.float32) / max_norm
         values_clip = self.reshape(values_clip, self.shape(x))
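For context, a minimal usage sketch of the fixed default-axis path (shapes and values are made up, and it assumes the standard nn.ClipByNorm call signature): with axis left at its default, the whole tensor is scaled so its global L2-norm does not exceed clip_norm, and the compile-time _need_reduce_all check restores the expand_dims branch that the old `self.axis is None` test apparently never reached, since axis seems to be normalized to a tuple in __init__.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.ClipByNorm()                                  # default axis: reduce over all dimensions
x = Tensor(np.random.rand(4, 16).astype(np.float32))
clip_norm = Tensor(np.array([1.0]).astype(np.float32))
out = net(x, clip_norm)                                # global L2-norm of `out` is at most 1.0
print(out.shape)                                       # (4, 16)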

@@ -15,7 +15,7 @@
 """embedding"""
 import mindspore.common.dtype as mstype
 from mindspore import log as logger
-from mindspore.common.tensor import Tensor
+from mindspore.common.tensor import Tensor, MetaTensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
@@ -101,8 +101,11 @@ class Embedding(Cell):
         if padding_idx is not None:
             self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                          "padding_idx", self.cls_name)
-            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
+            if isinstance(self.init_tensor, MetaTensor):
+                self.init_tensor = self.init_tensor.to_tensor()
+            self.init_tensor = self.init_tensor.asnumpy()
             self.init_tensor[self.padding_idx] = 0
+            self.init_tensor = Tensor(self.init_tensor)
         self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
         self.expand = P.ExpandDims()
         self.reshape_flat = P.Reshape()
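The old line called to_tensor() unconditionally, which only works when embedding_table arrives as an initializer (a MetaTensor); a concrete Tensor table has no to_tensor() and would raise. The guarded conversion, plus wrapping the zeroed NumPy array back into a Tensor before it becomes the Parameter, makes padding_idx work for both input kinds. A minimal sketch of the two paths (sizes and values are illustrative, and it assumes the documented nn.Embedding signature):

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

vocab_size, embedding_size = 10, 4

# Initializer path: the table stays a MetaTensor until to_tensor() materializes it.
net_a = nn.Embedding(vocab_size, embedding_size, embedding_table='normal', padding_idx=0)

# Concrete-Tensor path: already materialized, so only asnumpy() is needed.
table = Tensor(np.ones((vocab_size, embedding_size)).astype(np.float32))
net_b = nn.Embedding(vocab_size, embedding_size, embedding_table=table, padding_idx=0)

ids = Tensor(np.array([[0, 2, 5]]), mstype.int32)
print(net_b(ids).shape)  # (1, 3, 4); row 0 of the table has been zeroed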

@@ -77,7 +77,9 @@ def _get_square_sum(x):
 apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
 
 
 @apply_global_norm.register("Tensor", "Tensor", "Tensor")
 def _apply_global_norm(clip_norm, global_norm, x):
+    x_dtype = F.dtype(x)
     x = x * clip_norm / global_norm
+    x = F.cast(x, x_dtype)
     return x
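clip_norm and global_norm are float32, so `x * clip_norm / global_norm` promotes a float16 gradient to float32 and the clipped value would come back with the wrong dtype; saving the input dtype and casting back restores it. A standalone sketch of the same pattern (hypothetical tensors, run outside the multitype func graph; it assumes the usual float16/float32 implicit promotion on multiply):

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import functional as F

grad = Tensor(np.ones((2, 2)).astype(np.float16))  # e.g. a float16 gradient
clip_norm = Tensor(1.0, mstype.float32)
global_norm = Tensor(4.0, mstype.float32)

x_dtype = F.dtype(grad)                 # remember the original dtype
x = grad * clip_norm / global_norm      # promoted to float32 by the float32 operands
x = F.cast(x, x_dtype)                  # restore float16 before returning
print(F.dtype(x))                       # Float16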

@@ -6655,7 +6655,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
             args = {'init_h': h_dtype, 'bias_input': binput_dtype}
             validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
             b_dtype = binput_dtype
-        elif bhidden_dtype is not None:
+        if bhidden_dtype is not None:
             args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
             validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
             b_dtype = bhidden_dtype
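With elif, the bias_hidden dtype was only validated when bias_input was absent; making the second check an independent if means both bias dtypes are validated against init_h, and b_dtype ends up as the hidden-bias dtype when both are supplied. A plain-Python sketch of the resulting control flow (not the operator code; names and dtypes are illustrative):

def infer_bias_dtype(h_dtype, binput_dtype=None, bhidden_dtype=None):
    b_dtype = None
    if binput_dtype is not None:
        assert binput_dtype == h_dtype, "bias_input must match init_h"
        b_dtype = binput_dtype
    if bhidden_dtype is not None:     # was `elif`: previously skipped when bias_input was given
        assert bhidden_dtype == h_dtype, "bias_hidden must match init_h"
        b_dtype = bhidden_dtype
    return b_dtype

print(infer_bias_dtype("float16", binput_dtype="float16", bhidden_dtype="float16"))  # float16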
