Fix the validation of the Softmax, Tanh, and Elu operators.

Branch: pull/12102/head
Author: wangshuide2020 (4 years ago)
Parent: 2847a1f3e3
Commit: 8da6d65222

@@ -175,7 +175,7 @@ class ELU(Cell):
         ValueError: If `alpha` is not equal to 1.0.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
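A note on the newly listed CPU platform: below is a minimal usage sketch for nn.ELU, assuming a MindSpore build with the CPU backend enabled; the input mirrors the docstring example above.

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context

# Select the CPU backend that this hunk adds to the supported platforms.
context.set_context(device_target="CPU")

elu = nn.ELU()
input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
# ELU(x) = x for x >= 0 and alpha * (exp(x) - 1) for x < 0.
print(elu(input_x))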

@@ -172,7 +172,7 @@ class Softmax(PrimitiveWithInfer):
         return logits

     def infer_dtype(self, logits):
-        validator.check_tensor_dtype_valid("logits", logits, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
         return logits
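The relaxed check swaps the hard-coded (mstype.float16, mstype.float32) tuple for mstype.float_type, which in mindspore.common.dtype covers float16, float32, and float64. Below is a minimal sketch of what the validation now admits; passing the dtype check does not by itself guarantee a float64 Softmax kernel on every backend.

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

softmax = P.Softmax()
# Previously rejected by infer_dtype; now accepted because float64 is a
# member of mstype.float_type.
logits = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float64)
output = softmax(logits)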
@@ -603,7 +603,7 @@ class Elu(PrimitiveWithInfer):
         return input_x

     def infer_dtype(self, input_x):
-        validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
         return input_x
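For reference, a short inspection of the tuple the relaxed Elu check validates against; this assumes the standard mindspore.common.dtype definition of float_type.

import mindspore.common.dtype as mstype

# Expected to list the three floating-point dtypes: float16, float32,
# float64 (the exact repr may vary between MindSpore versions).
print(mstype.float_type)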
@@ -761,7 +761,7 @@ class Tanh(PrimitiveWithInfer):
         TypeError: If dtype of `input_x` is neither float16 nor float32.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
@@ -779,7 +779,7 @@ class Tanh(PrimitiveWithInfer):
         return input_x

     def infer_dtype(self, input_x):
-        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid("input_x", input_x, mstype.float_type, self.name)
         return input_x
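The same relaxation applied to Tanh, exercised below with the docstring example widened to float64; a sketch assuming the chosen backend ships a float64 Tanh kernel.

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

tanh = P.Tanh()
# float64 now clears the dtype validation broadened in this hunk.
input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float64)
output = tanh(input_x)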
