|
|
@ -20,7 +20,7 @@ from ..._checkparam import Validator as validator
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
|
|
|
|
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
|
|
|
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReduceOp:
|
|
|
|
class ReduceOp:
|
|
|
@ -518,6 +518,59 @@ class Broadcast(PrimitiveWithInfer):
|
|
|
|
return x_dtype
|
|
|
|
return x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AllSwap(PrimitiveWithCheck):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
AllSwap is a collective operation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AllSwap sends data from the all processes to the all processes in the specified group. It has two phases:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- The scatter phase: On each process, the operand is split into the send size of blocks along the
|
|
|
|
|
|
|
|
0-th axis, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process.
|
|
|
|
|
|
|
|
- The gather phase: Each process concatenates the received blocks along the 0-th axis.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
|
|
|
The tensors must have the same format in all processes of the collection.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
group (str): The communication group name.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
|
|
|
tensor_in (tensor): A 2-D tensor. On each process, divide blocks into number of the send size.
|
|
|
|
|
|
|
|
send_size (tensor): A 1-D int64 tensor. The element is the send data size for each process.
|
|
|
|
|
|
|
|
recv_size (tensor): A 1-D int64 tensor. The element is the receive data size for each process.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
tensor_out (tensor): The result tensor.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
|
|
|
TypeError: If group is not a string.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
|
|
|
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
|
|
|
"""Initialize AllSwap"""
|
|
|
|
|
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
|
|
|
|
|
|
|
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
|
|
|
|
|
|
|
|
self.add_prim_attr('group', _get_group(group))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check__(self, tensor_in, send_size, recv_size):
|
|
|
|
|
|
|
|
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
|
|
|
|
|
|
|
|
self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
|
|
|
|
|
|
|
|
self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
|
|
|
|
|
|
|
|
validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
|
|
|
|
|
|
|
|
validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out_shape = [-1] + [tensor_in['shape'][1]]
|
|
|
|
|
|
|
|
out = {'shape': out_shape,
|
|
|
|
|
|
|
|
'dtype': tensor_in['dtype'],
|
|
|
|
|
|
|
|
'value': None}
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AlltoAll(PrimitiveWithInfer):
|
|
|
|
class _AlltoAll(PrimitiveWithInfer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
AlltoAll is a collective operation.
|
|
|
|
AlltoAll is a collective operation.
|
|
|
|