|
|
|
@ -66,15 +66,30 @@ class Net2(nn.Cell):
|
|
|
|
|
return x - y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
|
|
|
|
|
class Net3(nn.Cell):
    """Two chained MatMul ops; the second weight is created with parallel_optimizer=False."""

    def __init__(self, strategy1, strategy2):
        """Create the sharded MatMul primitives and their weights.

        Args:
            strategy1: shard strategy applied to the first MatMul.
            strategy2: shard strategy applied to the second MatMul.
        """
        super(Net3, self).__init__()
        self.fc1 = P.MatMul().shard(strategy=strategy1)
        self.fc2 = P.MatMul().shard(strategy=strategy2)
        # 48x64 weight consumed by the first MatMul.
        weight1_init = Tensor(np.ones([48, 64]).astype(np.float32))
        self.p1 = Parameter(weight1_init, name="weight1")
        # 64x16 weight consumed by the second MatMul; explicitly opted out of
        # the parallel optimizer.
        weight2_init = Tensor(np.ones([64, 16]).astype(np.float32))
        self.p2 = Parameter(weight2_init, name="weight2", parallel_optimizer=False)

    def construct(self, x, y):
        """Return fc2(fc1(x, weight1), weight2) - y."""
        hidden = self.fc1(x, self.p1)
        projected = self.fc2(hidden, self.p2)
        return projected - y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
    """Instantiate `net` with the given strategies and compile one training step.

    Args:
        mode: value forwarded as `parallel_mode` to the auto-parallel context.
        dev_num: value forwarded as `device_num` to the auto-parallel context.
        net: network class (not an instance); it is called as
            `net(strategy1, strategy2)` to build the network under test.
        strategy1: shard strategy passed as the first constructor argument.
        strategy2: shard strategy passed as the second constructor argument.
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
    inputs = Tensor(np.ones([32, 48]).astype(np.float32))
    label = Tensor(np.zeros([32, 16]).astype(np.float32))
    # Build the network from the class supplied by the caller; do not
    # overwrite `net` with a hard-coded Net2 instance first.
    net = net(strategy1, strategy2)
    net = _VirtualDatasetCell(net)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    # Create the train cell once, with communication fusion set to 4.
    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
    train_network.set_auto_parallel()
    train_network.set_train()
    _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
    # NOTE(review): callers read `.parameter_layout_dict` from this function's
    # result, so a `return train_network` is expected directly after this
    # point -- it is not visible in this excerpt; confirm it is present.
|
|
|
|
@ -83,18 +98,18 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_auto_parallel_momentum_1():
    """Compile Net2 in auto_parallel mode on 8 devices with no explicit strategies."""
    # The network class is a required positional argument of the helper.
    auto_parallel_compile_net("auto_parallel", 8, Net2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_auto_parallel_momentum_2():
    """Compile Net2 in auto_parallel mode on 8 devices with data-parallel strategies."""
    # data parallel case
    auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_auto_parallel_momentum_3():
    """Compile Net2 in semi_auto_parallel mode on 32 devices and check weight1's layout."""
    # hybrid parallel case
    # weight1 could not be sharded and weight2 is repeated
    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
    param_dict = train_network.parameter_layout_dict
    # validate opt_shard_group
    assert not param_dict["weight1"][5]
|
|
|
|
@ -104,7 +119,16 @@ def test_auto_parallel_momentum_3():
|
|
|
|
|
def test_auto_parallel_momentum_4():
    """Compile Net2 in semi_auto_parallel mode on 32 devices with repeated device use."""
    # hybrid parallel cases
    # devices are repeatedly used
    auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_auto_parallel_momentum_5():
    """Compile Net3 and check that neither weight has an opt_shard_group entry."""
    # test parallel optimizer filter
    layouts = auto_parallel_compile_net(
        "semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2))
    ).parameter_layout_dict
    # validate opt_shard_group (index 5 of each layout entry)
    for weight_name in ("weight1", "weight2"):
        assert not layouts[weight_name][5]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_AdamWeightDecay():
|
|
|
|
|