|
|
|
@ -59,7 +59,7 @@ if __name__ == '__main__':
|
|
|
|
|
if not args_opt.do_eval and args_opt.run_distribute:
|
|
|
|
|
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
mirror_mean=True, parameter_broadcast=True)
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
|
|
|
|
|
init()
|
|
|
|
|
|
|
|
|
|
epoch_size = config.epoch_size
|
|
|
|
@ -91,7 +91,7 @@ if __name__ == '__main__':
|
|
|
|
|
loss_cb = LossMonitor()
|
|
|
|
|
cb = [time_cb, loss_cb]
|
|
|
|
|
if config.save_checkpoint:
|
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
|
|
|
|
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
|
|
|
|
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
|
|
|
|
|
cb += [ckpt_cb]
|
|
|
|
|