|
|
|
@@ -172,7 +172,7 @@ class Softmax(PrimitiveWithInfer):
|
|
|
|
|
return logits
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, logits):
    """Validate and propagate the dtype of `logits`.

    Args:
        logits: Dtype of the input tensor; must be a float type.

    Returns:
        The input dtype, unchanged, once validation passes.

    Raises:
        TypeError: If `logits` is not a float tensor type.
    """
    # NOTE(review): diff residue left both the old (float16/float32-only)
    # and the new (any float type) checks in place; running both would undo
    # the widening. Keep only the widened check.
    validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
    return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -603,7 +603,7 @@ class Elu(PrimitiveWithInfer):
|
|
|
|
|
return input_x
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, input_x):
    """Validate and propagate the dtype of `input_x`.

    Args:
        input_x: Dtype of the input tensor; must be a float type.

    Returns:
        The input dtype, unchanged, once validation passes.

    Raises:
        TypeError: If `input_x` is not a float tensor type.
    """
    # NOTE(review): diff residue left both the old (float16/float32-only)
    # and the new (any float type) checks in place; running both would undo
    # the widening. Keep only the widened check.
    validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
    return input_x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -761,7 +761,7 @@ class Tanh(PrimitiveWithInfer):
|
|
|
|
|
TypeError: If dtype of `input_x` is not float.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU`` ``CPU``
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
|
|
|
|
@@ -779,7 +779,7 @@ class Tanh(PrimitiveWithInfer):
|
|
|
|
|
return input_x
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, input_x):
    """Validate and propagate the dtype of `input_x`.

    Args:
        input_x: Dtype of the input tensor; must be a float type.

    Returns:
        The input dtype, unchanged, once validation passes.

    Raises:
        TypeError: If `input_x` is not a float tensor type.
    """
    # NOTE(review): diff residue left both the old (float16/float32-only)
    # and the new (any float type) checks in place; running both would undo
    # the widening. Keep only the widened check.
    validator.check_tensor_dtype_valid("input_x", input_x, mstype.float_type, self.name)
    return input_x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|