add resume for training

fixing bug

bug fixed

updates

pylint-fixed

assert-msg

assert
minara 4 years ago
parent 03f0e64af9
commit 5859204046

@@ -80,7 +80,17 @@ if __name__ == '__main__':
     lr = get_lr(lr_init=cfg.lr_init, lr_decay_rate=cfg.lr_decay_rate,
                 num_epoch_per_decay=cfg.num_epoch_per_decay, total_epochs=cfg.epoch_size,
                 steps_per_epoch=batches_per_epoch, is_stair=True)
-    lr = Tensor(lr)
+    if args_opt.resume:
+        name_dir = os.path.basename(args_opt.resume)
+        name, ext = name_dir.split(".")
+        split_result = name.split("_")
+        resume = split_result[-2].split("-")
+        resume_epoch = int(resume[-1])
+        step_num_in_epoch = int(split_result[-1])
+        assert step_num_in_epoch == ds_train.get_dataset_size(), \
+            "This script only supports resuming at the end of epoch"
+        lr = lr[(ds_train.get_dataset_size() * (resume_epoch - 1) + step_num_in_epoch):]
+    lr = Tensor(lr, mstype.float32)
     # optimizer
     decayed_params = []
@@ -108,8 +118,14 @@ if __name__ == '__main__':
     if args_opt.is_distributed & cfg.is_save_on_master:
         if cfg.rank == 0:
             callbacks.append(ckpoint_cb)
+        if args_opt.resume:
+            model.train(cfg.epoch_size - resume_epoch, dataset, callbacks=callbacks, dataset_sink_mode=True)
+        else:
+            model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
     else:
         callbacks.append(ckpoint_cb)
+        if args_opt.resume:
+            model.train(cfg.epoch_size - resume_epoch, dataset, callbacks=callbacks, dataset_sink_mode=True)
+        else:
+            model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
     print("train success")
