|
|
|
|
@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
The operation of AllReduce does not support "prod" currently.
|
|
|
|
|
The input of AllReduce does not support dtype "Bool".
|
|
|
|
|
Tensor must have same shape and format in all processes participating in the collective.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
@ -218,7 +217,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
@ -275,11 +274,13 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
if not isinstance(x_dtype, tuple):
|
|
|
|
|
raise TypeError(f"{self.name}'s input should be a tuple!")
|
|
|
|
|
for _ele in x_dtype:
|
|
|
|
|
if _ele.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AlltoAll(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
AlltoAll is a collective operation.
|
|
|
|
|
@ -318,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|