!8580 Add Operation AllSwap Primitive

From: @huangxinjing
Reviewed-by: 
Signed-off-by:
pull/8580/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 39352ca658

@@ -217,6 +217,8 @@ AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const Primit
                                          const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                     const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -367,6 +367,45 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co
  return sparse_tensor->dense_shape();
}

AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 3);
  auto tensor_in = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(tensor_in);
  MS_EXCEPTION_IF_NULL(tensor_in->shape());
  auto tensor_in_shape = tensor_in->shape()->shape();
  auto send_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(send_size);
  auto recv_size = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
  MS_EXCEPTION_IF_NULL(recv_size);
  // Read the receive sizes: their sum bounds the total number of elements in the output.
  auto recv_size_value_ptr = recv_size->BuildValue();
  MS_EXCEPTION_IF_NULL(recv_size_value_ptr);
  auto recv_size_tensor = recv_size_value_ptr->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(recv_size_tensor);
  auto data_pos = reinterpret_cast<int64_t *>(recv_size_tensor->data_c());
  MS_EXCEPTION_IF_NULL(data_pos);
  int64_t infer_max_size = 0;
  for (int64_t i = 0; i < recv_size_tensor->DataSize(); ++i) {
    infer_max_size += *(data_pos + i);
  }
  // The first dimension of the output is dynamic: at least one row, at most the total
  // received elements divided by the row width of the 2-D input.
  ShapeVector tensor_out_shape = {Shape::SHP_ANY, tensor_in_shape[1]};
  ShapeVector min_shape = {1, tensor_in_shape[1]};
  ShapeVector max_shape = {infer_max_size / tensor_in_shape[1], tensor_in_shape[1]};
  AbstractTensorPtr ret = std::make_shared<AbstractTensor>(
    tensor_in->element(), std::make_shared<Shape>(tensor_out_shape, min_shape, max_shape));
  return ret;
}
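
The shape bound above reduces to simple arithmetic over the receive sizes. A minimal Python sketch of that arithmetic follows; the (100, 20) input shape and the recv_size values are hypothetical example numbers, not values from the commit:

# Sketch of the dynamic-shape bound derived by InferImplAllSwap (hypothetical numbers).
tensor_in_shape = (100, 20)              # 2-D input: (rows, row_width)
recv_size = [1000, 1000, 0]              # elements expected from each rank
infer_max_size = sum(recv_size)          # at most 2000 elements received
row_width = tensor_in_shape[1]           # 20
out_shape = (-1, row_width)              # first dimension is dynamic
min_shape = (1, row_width)               # at least one row
max_shape = (infer_max_size // row_width, row_width)  # at most (100, 20)
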
AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();

@@ -135,6 +135,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimAllReduce, {InferImplAllReduce, true}},
    {prim::kPrimBroadcast, {InferImplBroadcast, true}},
    {prim::kPrimAllGather, {InferImplAllGather, true}},
    {prim::kPrimAllSwap, {InferImplAllSwap, true}},
    {prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
    {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
    {prim::kPrimCast, {InferImplCast, true}},

@@ -186,6 +186,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");

@@ -21,7 +21,7 @@ from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                   _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters

@@ -155,6 +155,21 @@ def get_bprop_reduce_scatter(self):
    return bprop

@bprop_getters.register(AllSwap)
def get_bprop_allswap(self):
    """Generate bprop for AllSwap."""
    # The backward pass is another AllSwap in the same group with the send and
    # receive sizes exchanged: whatever a rank sent forward, it receives back.
    all_swap_grad = AllSwap(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_swap_grad.set_prim_instance_name(instance_name)

    def bprop(x, send_size, recv_size, out, dout):
        dx = all_swap_grad(dout, recv_size, send_size)
        return (dx, zeros_like(send_size), zeros_like(recv_size))

    return bprop

@bprop_getters.register(_HostReduceScatter)
def get_bprop_host_reduce_scatter(self):
    """Generate bprop for _HostReduceScatter"""

@@ -34,7 +34,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
                        Unique, GatherD, Identity, RepeatElements)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice, Send, Receive,
                       _HostAllGather, _HostReduceScatter)
@@ -295,6 +295,7 @@ __all__ = [
    'UnsortedSegmentProd',
    "AllGather",
    "AllReduce",
    "AllSwap",
    "ReduceScatter",
    "Broadcast",
    "ReduceOp",

@@ -20,7 +20,7 @@ from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register

class ReduceOp:

@@ -518,6 +518,59 @@ class Broadcast(PrimitiveWithInfer):
        return x_dtype

class AllSwap(PrimitiveWithCheck):
    """
    AllSwap is a collective operation.

    AllSwap sends data from all processes to all processes in the specified group. It has two phases:

    - The scatter phase: On each process, the operand is split into blocks along the 0-th axis
      according to the send sizes, and the blocks are scattered to all processes, e.g., the i-th
      block is sent to the i-th process.
    - The gather phase: Each process concatenates the received blocks along the 0-th axis.

    Note:
        The tensors must have the same format in all processes of the collection.

    Args:
        group (str): The communication group name.

    Inputs:
        tensor_in (tensor): A 2-D tensor to be split into blocks according to the send sizes.
        send_size (tensor): A 1-D int64 tensor. Each element is the number of elements to send to
            the corresponding process.
        recv_size (tensor): A 1-D int64 tensor. Each element is the number of elements to receive
            from the corresponding process.

    Returns:
        tensor_out (tensor): The result tensor.

    Raises:
        TypeError: If group is not a string.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllSwap"""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
        self.add_prim_attr('group', _get_group(group))

    def __check__(self, tensor_in, send_size, recv_size):
        validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
                                           self.name)
        validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
                                           self.name)
        validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
        validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
        validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)
        # The number of output rows is unknown at compile time, so the first dimension is
        # dynamic (-1); only the row width of the input is carried over.
        out_shape = [-1] + [tensor_in['shape'][1]]
        out = {'shape': out_shape,
               'dtype': tensor_in['dtype'],
               'value': None}
        return out
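
To make the input contract concrete, here is a hedged usage sketch. The two-rank layout and the element counts are assumptions for illustration; actually running it requires an initialized distributed backend (e.g. via init()) and as many ranks as send_size has entries:

# Hedged usage sketch for AllSwap; not runnable without an initialized
# communication backend and a matching number of ranks.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops.operations.comm_ops import AllSwap

all_swap = AllSwap()                              # defaults to the world communication group
tensor_in = Tensor(np.ones((4, 8)), ms.float32)   # 4 rows of width 8 -> 32 elements in total
send_size = Tensor([16, 16], ms.int64)            # send two rows' worth of elements to each of 2 ranks
recv_size = Tensor([16, 16], ms.int64)            # expect the same amount back from each rank
# out = all_swap(tensor_in, send_size, recv_size) # compile-time shape (-1, 8); rows depend on recv_size
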
class _AlltoAll(PrimitiveWithInfer):
    """
    AlltoAll is a collective operation.

@@ -26,7 +26,9 @@ from mindspore.nn import Momentum
from mindspore.nn import ReLU
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
from mindspore.ops.operations.math_ops import ReduceSum
import mindspore

# pylint: disable=W0212
# W0212: protected-access
@@ -117,6 +119,25 @@ class AlltoAllNet(nn.Cell):
        return self.relu(x)

class AllSwapNet(nn.Cell):
    """AllSwapNet definition"""
    def __init__(self, batch_size, input_channel, out_channel):
        super(AllSwapNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.allswap = AllSwap()
        self.relu = ReLU()
        self.reduce = ReduceSum()
        part_slice = batch_size // 2
        self.send_size = Tensor([0, part_slice * out_channel, part_slice * out_channel], mindspore.int64)
        self.recv_size = Tensor([part_slice * out_channel, part_slice * out_channel, 0], mindspore.int64)

    def construct(self, x):
        x = self.dense(x)
        x = self.allswap(x, self.send_size, self.recv_size)
        x = self.relu(x)
        return x

def run_allreduce(op):
    """run_allreduce"""
    context.set_context(mode=context.GRAPH_MODE)
@@ -154,6 +175,13 @@ def test_allgather():
    network = TrainOneStepCell(network, optimizer)
    _executor.compile(network, input_tensor, label_tensor)

def test_allswap():
    """test_allswap"""
    context.set_context(mode=context.GRAPH_MODE)
    input_tensor = Tensor(np.ones((100, 20)), dtype=mindspore.float32)
    network = AllSwapNet(100, 20, 20)
    _executor.compile(network, input_tensor)

def run_reducescatter(op):
    """run_reducescatter"""
