From 3390ae36297a45b91f8e24e008d437d7e406af9b Mon Sep 17 00:00:00 2001
From: jiangzhenguang
Date: Tue, 22 Dec 2020 20:38:05 +0800
Subject: [PATCH] fix operation

---
 mindspore/nn/layer/basic.py         | 9 ++++++++-
 mindspore/nn/layer/embedding.py     | 7 +++++--
 mindspore/ops/composite/clip_ops.py | 2 ++
 mindspore/ops/operations/nn_ops.py  | 2 +-
 4 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 31acf3bc7c..df681d5c54 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -352,6 +352,13 @@ def _is_float_dtype(dtype):
     return False
 
 
+@constexpr
+def _need_reduce_all(axis):
+    if axis == ():
+        return True
+    return False
+
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
@@ -424,7 +431,7 @@ class ClipByNorm(Cell):
         intermediate = x * clip_norm
 
         max_norm = self.max_op(l2norm, clip_norm)
-        if self.axis is None:
+        if _need_reduce_all(self.axis):
             max_norm = self.expand_dims(max_norm, -1)
         values_clip = self.cast(intermediate, mstype.float32) / max_norm
         values_clip = self.reshape(values_clip, self.shape(x))
diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index e8361e6b19..024432f3c0 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -15,7 +15,7 @@
 """embedding"""
 import mindspore.common.dtype as mstype
 from mindspore import log as logger
-from mindspore.common.tensor import Tensor
+from mindspore.common.tensor import Tensor, MetaTensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
@@ -101,8 +101,11 @@ class Embedding(Cell):
         if padding_idx is not None:
             self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                          "padding_idx", self.cls_name)
-            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
+            if isinstance(self.init_tensor, MetaTensor):
+                self.init_tensor = self.init_tensor.to_tensor()
+            self.init_tensor = self.init_tensor.asnumpy()
             self.init_tensor[self.padding_idx] = 0
+            self.init_tensor = Tensor(self.init_tensor)
         self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
         self.expand = P.ExpandDims()
         self.reshape_flat = P.Reshape()
diff --git a/mindspore/ops/composite/clip_ops.py b/mindspore/ops/composite/clip_ops.py
index b68b5f26c1..7836e23599 100644
--- a/mindspore/ops/composite/clip_ops.py
+++ b/mindspore/ops/composite/clip_ops.py
@@ -77,7 +77,9 @@ def _get_square_sum(x):
 apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
 @apply_global_norm.register("Tensor", "Tensor", "Tensor")
 def _apply_global_norm(clip_norm, global_norm, x):
+    x_dtype = F.dtype(x)
     x = x * clip_norm / global_norm
+    x = F.cast(x, x_dtype)
     return x
 
 
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 63b4851995..9a8f48b5bf 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -6638,7 +6638,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
             args = {'init_h': h_dtype, 'bias_input': binput_dtype}
             validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
             b_dtype = binput_dtype
-        elif bhidden_dtype is not None:
+        if bhidden_dtype is not None:
             args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
             validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
             b_dtype = bhidden_dtype
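
Note (not part of the patch): a minimal sketch of why _apply_global_norm now saves and restores the input dtype. Multiplying a float16 gradient by the float32 clip_norm / global_norm factor promotes the result to float32, so without the added F.cast the clipped tensor would come back in a different dtype than it went in. The shapes and values below are made up for illustration, and PyNative mode is assumed so the snippet runs standalone.

import numpy as np
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.ones((2, 2)), mstype.float16)   # e.g. a float16 gradient
clip_norm = Tensor(1.0, mstype.float32)
global_norm = Tensor(4.0, mstype.float32)

x_dtype = F.dtype(x)                          # remember the original dtype
x = x * clip_norm / global_norm               # implicit promotion to float32
x = F.cast(x, x_dtype)                        # cast back, as the patched code does
print(F.dtype(x))                             # Float16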