|
|
|
@ -142,16 +142,19 @@ class AllGather(PrimitiveWithInfer):
|
|
|
|
|
``Ascend`` ``GPU``
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> # This example should be run with two devices. Refer to the tutorial > Distirbuted Training on mindspore.cn.
|
|
|
|
|
>>> import numpy as np
|
|
|
|
|
>>> import mindspore.ops.operations as ops
|
|
|
|
|
>>> import mindspore.nn as nn
|
|
|
|
|
>>> from mindspore.communication import init
|
|
|
|
|
>>> from mindspore import Tensor
|
|
|
|
|
>>> from mindspore import Tensor, context
|
|
|
|
|
>>>
|
|
|
|
|
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
>>> init()
|
|
|
|
|
... class Net(nn.Cell):
|
|
|
|
|
... def __init__(self):
|
|
|
|
|
... super(Net, self).__init__()
|
|
|
|
|
... self.allgather = ops.AllGather(group="nccl_world_group")
|
|
|
|
|
... self.allgather = ops.AllGather()
|
|
|
|
|
...
|
|
|
|
|
... def construct(self, x):
|
|
|
|
|
... return self.allgather(x)
|
|
|
|
@ -160,6 +163,10 @@ class AllGather(PrimitiveWithInfer):
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> output = net(input_)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
|
|
|
[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
|
|
|
[1. 1. 1. 1. 1. 1. 1. 1.]
|
|
|
|
|
[1. 1. 1. 1. 1. 1. 1. 1.]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@ -255,16 +262,18 @@ class ReduceScatter(PrimitiveWithInfer):
|
|
|
|
|
ValueError: If the first dimension of the input cannot be divided by the rank size.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``GPU``
|
|
|
|
|
``Ascend`` ``GPU``
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> from mindspore import Tensor
|
|
|
|
|
>>> # This example should be run with two devices. Refer to the tutorial > Distirbuted Training on mindspore.cn.
|
|
|
|
|
>>> from mindspore import Tensor, context
|
|
|
|
|
>>> from mindspore.communication import init
|
|
|
|
|
>>> from mindspore.ops.operations.comm_ops import ReduceOp
|
|
|
|
|
>>> import mindspore.nn as nn
|
|
|
|
|
>>> import mindspore.ops.operations as ops
|
|
|
|
|
>>> import numpy as np
|
|
|
|
|
>>>
|
|
|
|
|
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
>>> init()
|
|
|
|
|
>>> class Net(nn.Cell):
|
|
|
|
|
... def __init__(self):
|
|
|
|
@ -278,6 +287,10 @@ class ReduceScatter(PrimitiveWithInfer):
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> output = net(input_)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
|
|
|
[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
|
|
|
[2. 2. 2. 2. 2. 2. 2. 2.]
|
|
|
|
|
[2. 2. 2. 2. 2. 2. 2. 2.]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|