|
|
|
@ -21,7 +21,6 @@ from mindspore import context
|
|
|
|
|
from mindspore.common.api import _executor
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
|
|
|
|
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
@ -33,7 +32,6 @@ grad_all = C.GradOperation(get_all=True)
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
def __init__(self, strategy1, strategy2, num_segments):
|
|
|
|
|
super(Net, self).__init__()
|
|
|
|
|
self.virtual_dataset = _VirtualDataset()
|
|
|
|
|
self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2))
|
|
|
|
|
self.num_segments = num_segments
|
|
|
|
|
|
|
|
|
@ -54,8 +52,8 @@ class GradWrap(nn.Cell):
|
|
|
|
|
class NetWithLoss(nn.Cell):
|
|
|
|
|
def __init__(self, network):
    """Wrap ``network`` together with a ``VirtualLoss`` head.

    Args:
        network: The cell whose output is fed to the loss in ``construct``.
    """
    super(NetWithLoss, self).__init__()
    # The loss cell is created exactly once; the previous body assigned
    # ``self.loss`` twice, building a throw-away VirtualLoss instance.
    # It is assigned before ``network`` so the attribute assignment order
    # (loss, then network) stays the same as before.
    self.loss = VirtualLoss()
    self.network = network
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y):
|
|
|
|
|
predict = self.network(x, y)
|
|
|
|
@ -63,13 +61,13 @@ class NetWithLoss(nn.Cell):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
    """Build the gradient-wrapped network and compile it in a parallel mode.

    Args:
        x: First input passed to the network at compile time.
        y: Second input passed to the network at compile time.
        segments: Forwarded to ``Net`` as ``num_segments``.
        strategy1: First element of the shard strategy given to ``Net``.
        strategy2: Second element of the shard strategy given to ``Net``.
        auto: When true, compile under ``"auto_parallel"``; otherwise under
            ``"semi_auto_parallel"``.
    """
    # Select the parallel mode first so the network is created under the
    # context it will be compiled with.
    parallel_mode = "auto_parallel" if auto else "semi_auto_parallel"
    context.set_auto_parallel_context(parallel_mode=parallel_mode)

    # Build the network exactly once. The previous body also constructed and
    # configured an identical network before the mode was set, then discarded
    # it by rebinding ``net``.
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
    net.set_auto_parallel()
    net.set_train()
    _executor.compile(net, x, y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -151,3 +149,13 @@ def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d():
|
|
|
|
|
strategy1 = (2, 1, 2)
|
|
|
|
|
strategy2 = (2, 1)
|
|
|
|
|
compile_graph(x, y, num_segments, strategy1, strategy2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_unsortedsegmentsum_model_parallel_repeat_caculate():
    """Compile with all-ones strategies while four devices are configured.

    Both strategies are all ones, and the parallel context is set to
    ``device_num=4`` with ``global_rank=0``. The test passes if
    ``compile_graph`` completes without raising.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)

    data = Tensor(np.ones((4, 4, 8)), ms.float32)
    segment_ids = Tensor(np.ones((4, 4)), ms.int32)

    # Positional order matches compile_graph(x, y, segments, strategy1, strategy2).
    compile_graph(data, segment_ids, 16, (1, 1, 1), (1, 1))
|
|
|
|
|