@@ -169,6 +169,72 @@ class AllGather(PrimitiveWithInfer):
        raise NotImplementedError


class HostAllGather(PrimitiveWithInfer):
    """
    Gathers tensors from the specified communication group on host.

    Note:
        Tensor must have the same shape and format in all processes participating in the collective.

    Args:
        group (Union[tuple[int], list[int]]): The rank_ids of the communication group to work on.

    Raises:
        TypeError: If group is not a list or tuple, or if any element of group is not an int.
        ValueError: If the local rank id of the calling process is not in group,
            or if a rank_id in group is not in [0, 7].

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor. If the number of devices in the group is N,
        then the shape of output is :math:`(N*x_1, x_2, ..., x_R)`.
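        For example, four devices each contributing a :math:`(2, 8)` tensor yield a gathered output of shape :math:`(8, 8)`.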

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.ops.operations as P
        >>> from mindspore import Tensor
        >>> from mindspore.communication import init
        >>> init('nccl')
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
        >>>
        >>>     def construct(self, x):
        >>>         return self.hostallgather(x)
        >>>
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
    """

    @prim_attr_register
    def __init__(self, group=None):
        if group is None:
            raise ValueError(f"For '{self.name}' group must be set.")
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
        for r in group:
            validator.check_value_type("rank_id", r, (int,), self.name)
            validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
        self.group_size = len(group)
        self.rank = get_rank()
        validator.check('rank', self.rank, 'group', self.group, Rel.IN, self.name)
        self.add_prim_attr('group', group)

    def infer_shape(self, x_shape):
        validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
        x_shape[0] = x_shape[0] * self.group_size
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
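        # Calling the primitive directly is not supported; it is only meaningful inside a constructed network.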
        raise NotImplementedError


class ReduceScatter(PrimitiveWithInfer):
    """
    Reduces and scatters tensors from the specified communication group.
@@ -226,6 +292,68 @@ class ReduceScatter(PrimitiveWithInfer):
        raise NotImplementedError


class HostReduceScatter(PrimitiveWithInfer):
    """
    Reduces and scatters tensors from the specified communication group on host.

    Note:
        Tensor must have the same shape and format in all processes participating in the collective.

    Args:
        op (str): Specifies an operation used for element-wise reductions,
            like sum, max, avg. Default: ReduceOp.SUM.
        group (Union[tuple[int], list[int]]): The rank_ids of the communication group to work on.

    Raises:
        TypeError: If op is not a string, if group is not a list or tuple,
            or if any element of group is not an int.
        ValueError: If the first dimension of the input cannot be divided by the rank size,
            or if group is not set, or if a rank_id is not in [0, 7].
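
    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
          where the first dimension x_1 must be divisible by the group size N.

    Outputs:
        Tensor. The shape of output is :math:`(x_1/N, x_2, ..., x_R)`.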

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.ops.operations as P
        >>> from mindspore import Tensor
        >>> from mindspore.communication import init
        >>> from mindspore.ops.operations import ReduceOp
        >>> init('nccl')
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
        >>>
        >>>     def construct(self, x):
        >>>         return self.hostreducescatter(x)
        >>>
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=None):
        if group is None:
            raise ValueError(f"For '{self.name}' group must be set.")
        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
        for r in group:
            validator.check_value_type("rank_id", r, (int,), self.name)
            validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
        self.op = op
        self.group_size = len(group)
        self.add_prim_attr('group', group)

    def infer_shape(self, x_shape):
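        # ReduceScatter splits the reduced result evenly across the group,
        # so the first dimension must be divisible by the group size.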
        if x_shape[0] % self.group_size != 0:
            raise ValueError(f"For '{self.name}' the first dimension of x should be divisible by group_size.")
        x_shape[0] = int(x_shape[0] / self.group_size)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class Broadcast(PrimitiveWithInfer):
    """
    Broadcasts the tensor to the whole group.