|
|
|
@ -17,7 +17,7 @@
|
|
|
|
|
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
from ...communication.management import get_rank, get_group_size, GlobalComm, get_group
|
|
|
|
|
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
|
|
|
|
|
|
|
|
@ -88,10 +88,10 @@ class AllReduce(PrimitiveWithInfer):
|
|
|
|
|
raise TypeError("The operation of AllReduce should be str.")
|
|
|
|
|
if op == ReduceOp.PROD:
|
|
|
|
|
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.")
|
|
|
|
|
if not isinstance(get_group(group), str):
|
|
|
|
|
if not isinstance(_get_group(group), str):
|
|
|
|
|
raise TypeError("The group of AllReduce should be str.")
|
|
|
|
|
self.op = op
|
|
|
|
|
self.add_prim_attr('group', get_group(group))
|
|
|
|
|
self.add_prim_attr('group', _get_group(group))
|
|
|
|
|
self.add_prim_attr('fusion', 0)
|
|
|
|
|
|
|
|
|
|
def vm_impl(self, x):
|
|
|
|
@ -149,12 +149,12 @@ class AllGather(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name)
|
|
|
|
|
self.rank = get_rank(get_group(group))
|
|
|
|
|
self.rank_size = get_group_size(get_group(group))
|
|
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
|
|
|
|
self.rank = get_rank(_get_group(group))
|
|
|
|
|
self.rank_size = get_group_size(_get_group(group))
|
|
|
|
|
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
|
|
|
|
|
self.add_prim_attr('rank_size', self.rank_size)
|
|
|
|
|
self.add_prim_attr('group', get_group(group))
|
|
|
|
|
self.add_prim_attr('group', _get_group(group))
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
x_shape[0] = x_shape[0] * self.rank_size
|
|
|
|
@ -205,11 +205,11 @@ class ReduceScatter(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
|
|
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name)
|
|
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
|
|
|
|
self.op = op
|
|
|
|
|
self.rank_size = get_group_size(get_group(group))
|
|
|
|
|
self.rank_size = get_group_size(_get_group(group))
|
|
|
|
|
self.add_prim_attr('rank_size', self.rank_size)
|
|
|
|
|
self.add_prim_attr('group', get_group(group))
|
|
|
|
|
self.add_prim_attr('group', _get_group(group))
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
if x_shape[0] % self.rank_size != 0:
|
|
|
|
@ -268,8 +268,8 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
validator.check_value_type('root_rank', root_rank, (int,), self.name)
|
|
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name)
|
|
|
|
|
self.add_prim_attr('group', get_group(group))
|
|
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
|
|
|
|
self.add_prim_attr('group', _get_group(group))
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
return x_shape
|
|
|
|
@ -306,11 +306,11 @@ class _AlltoAll(PrimitiveWithInfer):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
"""init AlltoAll"""
|
|
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name)
|
|
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
|
|
|
|
self.split_count = split_count
|
|
|
|
|
self.split_dim = split_dim
|
|
|
|
|
self.concat_dim = concat_dim
|
|
|
|
|
self.add_prim_attr('group', get_group(group))
|
|
|
|
|
self.add_prim_attr('group', _get_group(group))
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
|
x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
|
|
|
|
|