|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
|
|
|
|
|
""" test Communicate """
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp
|
|
|
|
|
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
|
|
|
|
|
from mindspore.ops.operations.comm_ops import Broadcast
|
|
|
|
|
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, GlobalComm, init
|
|
|
|
|
from mindspore.communication._comm_helper import Backend
|
|
|
|
@ -78,6 +78,19 @@ class AllGatherNet(nn.Cell):
|
|
|
|
|
x = self.allgather(x)
|
|
|
|
|
return self.relu(x)
|
|
|
|
|
|
|
|
|
|
class ReduceScatterNet(nn.Cell):
|
|
|
|
|
"""ReduceScatterNet definition"""
|
|
|
|
|
def __init__(self, input_channel, out_channel, op):
|
|
|
|
|
super(ReduceScatterNet, self).__init__()
|
|
|
|
|
self.dense = Dense(input_channel, out_channel)
|
|
|
|
|
self.reducescatter = ReduceScatter(op)
|
|
|
|
|
self.relu = ReLU()
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
x = self.dense(x)
|
|
|
|
|
x = self.reducescatter(x)
|
|
|
|
|
return self.relu(x)
|
|
|
|
|
|
|
|
|
|
class AlltoAllNet(nn.Cell):
|
|
|
|
|
"""AlltoAllNet definition"""
|
|
|
|
|
def __init__(self, input_channel, out_channel):
|
|
|
|
@ -126,6 +139,25 @@ def test_allgather():
|
|
|
|
|
network = TrainOneStepCell(network, optimizer)
|
|
|
|
|
_executor.compile(network, input_tensor, label_tensor)
|
|
|
|
|
|
|
|
|
|
def run_reducescatter(op):
|
|
|
|
|
"""run_reducescatter"""
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
|
|
|
|
|
label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
|
|
|
|
|
network = ReduceScatterNet(2, 1, op)
|
|
|
|
|
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
|
|
|
|
optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
|
|
|
|
|
learning_rate=0.1,
|
|
|
|
|
momentum=0.9)
|
|
|
|
|
network = WithLossCell(network, loss_fn)
|
|
|
|
|
network = TrainOneStepCell(network, optimizer)
|
|
|
|
|
_executor.compile(network, input_tensor, label_tensor)
|
|
|
|
|
|
|
|
|
|
def test_reducescatter():
|
|
|
|
|
"""test_reducescatter"""
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
run_reducescatter(ReduceOp.SUM)
|
|
|
|
|
|
|
|
|
|
def test_broadcast():
|
|
|
|
|
"""test_broadcast"""
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|