|
|
|
@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
The operation of AllReduce does not support "prod" currently.
|
|
|
|
|
The input of AllReduce does not support dtype "Bool".
|
|
|
|
|
Tensor must have same shape and format in all processes participating in the collective.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
@ -176,6 +175,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|
|
|
|
Note:
|
|
|
|
|
The back propagation of the op is not surported yet. Stay tuned for more.
|
|
|
|
|
Tensor must have the same shape and format in all processes participating in the collective.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
op (str): Specifies an operation used for element-wise reductions,
|
|
|
|
|
like sum, max, avg. Default: ReduceOp.SUM.
|
|
|
|
@ -218,7 +218,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
@ -275,8 +275,11 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
if not isinstance(x_dtype, tuple):
|
|
|
|
|
raise TypeError(f"{self.name}'s input should be a tuple!")
|
|
|
|
|
for _ele in x_dtype:
|
|
|
|
|
if _ele.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -318,7 +321,7 @@ class _AlltoAll(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype):
|
|
|
|
|
if x_dtype == mstype.bool_:
|
|
|
|
|
if x_dtype.element_type() == mstype.bool_:
|
|
|
|
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|