!12102 fix the validation of Softmax, Tanh, Elu operators.

From: @wangshuide2020
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
pull/12102/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 150dad4e46

@@ -175,7 +175,7 @@ class ELU(Cell):
ValueError: If `alpha` is not equal to 1.0.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)

@@ -172,7 +172,7 @@ class Softmax(PrimitiveWithInfer):
return logits
def infer_dtype(self, logits):
validator.check_tensor_dtype_valid("logits", logits, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
return logits
@@ -603,7 +603,7 @@ class Elu(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
return input_x
@@ -761,7 +761,7 @@ class Tanh(PrimitiveWithInfer):
TypeError: If dtype of `input_x` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
@@ -779,7 +779,7 @@ class Tanh(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_dtype_valid("input_x", input_x, mstype.float_type, self.name)
return input_x

Loading…
Cancel
Save