From 8da6d65222cb7d828aaa8eced60b84c536e44172 Mon Sep 17 00:00:00 2001
From: wangshuide2020 <7511764+wangshuide2020@user.noreply.gitee.com>
Date: Thu, 4 Feb 2021 14:49:04 +0800
Subject: [PATCH] fix the validation of Softmax, Tanh, Elu operators.

---
 mindspore/nn/layer/activation.py   | 2 +-
 mindspore/ops/operations/nn_ops.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py
index 3501144e4f..421d25524f 100644
--- a/mindspore/nn/layer/activation.py
+++ b/mindspore/nn/layer/activation.py
@@ -175,7 +175,7 @@ class ELU(Cell):
         ValueError: If `alpha` is not equal to 1.0.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index b57a3efa6e..a183dbda9d 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -172,7 +172,7 @@ class Softmax(PrimitiveWithInfer):
         return logits
 
     def infer_dtype(self, logits):
-        validator.check_tensor_dtype_valid("logits", logits, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
         return logits
 
 
@@ -603,7 +603,7 @@ class Elu(PrimitiveWithInfer):
         return input_x
 
     def infer_dtype(self, input_x):
-        validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
         return input_x
 
 
@@ -761,7 +761,7 @@ class Tanh(PrimitiveWithInfer):
         TypeError: If dtype of `input_x` is neither float16 nor float32.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
@@ -779,7 +779,7 @@ class Tanh(PrimitiveWithInfer):
         return input_x
 
     def infer_dtype(self, input_x):
-        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid("input_x", input_x, mstype.float_type, self.name)
         return input_x
 
 
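Illustrative note (not part of the patch): switching the validators from the explicit (mstype.float16, mstype.float32) tuple to mstype.float_type relaxes the dtype check to every float type MindSpore defines, including float64. A minimal sketch of the effect, assuming a backend whose Tanh kernel accepts float64 input:

    import numpy as np
    from mindspore import Tensor
    from mindspore import dtype as mstype
    from mindspore.ops import operations as P

    # With the old check, a float64 tensor was rejected at infer time with a
    # TypeError; mstype.float_type also contains float64, so validation now
    # passes. Whether the kernel actually computes in float64 still depends on
    # the backend (Ascend/GPU/CPU) and is not guaranteed by the validator alone.
    x = Tensor(np.array([1.0, 2.0, 3.0]), mstype.float64)
    out = P.Tanh()(x)
    print(out)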