|
|
|
@ -1712,6 +1712,7 @@ class FloatStatus(PrimitiveWithInfer):
|
|
|
|
|
return [1]
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name)
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
class NPUAllocFloatStatus(PrimitiveWithInfer):
|
|
|
|
|