gpu floatstatus add type check

pull/1222/head
VectorSL 5 years ago
parent f23bfe0d71
commit 2b51199054

@ -1712,6 +1712,7 @@ class FloatStatus(PrimitiveWithInfer):
return [1]
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name)
return x_dtype
class NPUAllocFloatStatus(PrimitiveWithInfer):

Loading…
Cancel
Save