|
|
|
|
@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
|
|
|
|
|
from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore import context
|
|
|
|
|
|
|
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|
|
|
|
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
"""Net definition"""
|
|
|
|
|
@ -64,6 +64,7 @@ def test_AdamWeightDecay():
|
|
|
|
|
net_with_loss = WithLossCell(net, loss)
|
|
|
|
|
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
|
|
|
|
_executor.compile(train_network, inputs, label)
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_lamb_compile():
|
|
|
|
|
@ -79,7 +80,24 @@ def test_lamb_compile():
|
|
|
|
|
net_with_loss = WithLossCell(net, loss)
|
|
|
|
|
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
|
|
|
|
_executor.compile(train_network, inputs, label)
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_lamb_split_fusion():
|
|
|
|
|
""" test_Lamb_split_fusion """
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
|
|
|
|
|
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
|
|
|
|
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
|
|
|
|
net = Net()
|
|
|
|
|
net.set_train()
|
|
|
|
|
loss = nn.SoftmaxCrossEntropyWithLogits()
|
|
|
|
|
optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
|
|
|
|
|
|
|
|
|
|
net_with_loss = WithLossCell(net, loss)
|
|
|
|
|
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
|
|
|
|
_executor.compile(train_network, inputs, label)
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
|
|
|
|
|
def test_edge_case():
|
|
|
|
|
""" test_edge_case """
|
|
|
|
|
@ -93,3 +111,4 @@ def test_edge_case():
|
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
|
context.set_auto_parallel_context(device_num=16)
|
|
|
|
|
Lamb(net.trainable_params(), learning_rate=0.1)
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
|