From a2850cae327f5d8d65ad1ac59b420139b0d85686 Mon Sep 17 00:00:00 2001 From: suteng Date: Sat, 11 Apr 2020 15:30:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!231=20:?= =?UTF-8?q?=20add=20bool=20type=20check=20in=20communication=20operator=20?= =?UTF-8?q?'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindspore/ops/operations/comm_ops.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 441e441c2c..1644c5800a 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -162,8 +162,6 @@ class AllGather(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("AllGather does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -221,8 +219,6 @@ class ReduceScatter(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -280,8 +276,6 @@ class Broadcast(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("Broadcast does not support 'Bool' as the dtype of input!") return x_dtype @@ -324,8 +318,6 @@ class _AlltoAll(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor):