|
|
|
@ -21,7 +21,7 @@ import argparse
|
|
|
|
|
from mindspore import context, Tensor, ParameterTuple
|
|
|
|
|
from mindspore.context import ParallelMode
|
|
|
|
|
from mindspore.communication.management import init, get_rank, get_group_size
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
from mindspore.nn.optim import Adam
|
|
|
|
|
from mindspore.nn import TrainOneStepCell
|
|
|
|
@ -89,7 +89,7 @@ if __name__ == '__main__':
|
|
|
|
|
print('Successfully loading the pre-trained model')
|
|
|
|
|
|
|
|
|
|
model = Model(train_net)
|
|
|
|
|
callback_list = [LossMonitor()]
|
|
|
|
|
callback_list = [TimeMonitor(steps_size), LossMonitor()]
|
|
|
|
|
|
|
|
|
|
if args.is_distributed:
|
|
|
|
|
config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank())
|
|
|
|
|