diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 441e441c2c..1644c5800a 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -162,8 +162,6 @@ class AllGather(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("AllGather does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -221,8 +219,6 @@ class ReduceScatter(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -280,8 +276,6 @@ class Broadcast(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("Broadcast does not support 'Bool' as the dtype of input!") return x_dtype @@ -324,8 +318,6 @@ class _AlltoAll(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor):