diff --git a/example/resnet101_imagenet2012/README.md b/example/resnet101_imagenet2012/README.md index d5729b70db..852326c9d6 100644 --- a/example/resnet101_imagenet2012/README.md +++ b/example/resnet101_imagenet2012/README.md @@ -51,8 +51,8 @@ Parameters for both training and evaluating can be set in config.py. "image_height": 224, # image height "image_width": 224, # image width "save_checkpoint": True, # whether save checkpoint or not -"save_checkpoint_steps": 500, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step -"keep_checkpoint_max": 40, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch +"keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path "warmup_epochs": 0, # number of warmup epoch "lr_decay_mode": "cosine" # decay mode for generating learning rate diff --git a/example/resnet101_imagenet2012/config.py b/example/resnet101_imagenet2012/config.py index ca58f24da3..0b9f16b504 100755 --- a/example/resnet101_imagenet2012/config.py +++ b/example/resnet101_imagenet2012/config.py @@ -28,8 +28,8 @@ config = ed({ "image_height": 224, "image_width": 224, "save_checkpoint": True, - "save_checkpoint_steps": 500, - "keep_checkpoint_max": 40, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 10, "save_checkpoint_path": "./", "warmup_epochs": 0, "lr_decay_mode": "cosine", diff --git a/example/resnet101_imagenet2012/eval.py b/example/resnet101_imagenet2012/eval.py index 979c6ca949..bdf6e89ca8 100755 --- a/example/resnet101_imagenet2012/eval.py +++ b/example/resnet101_imagenet2012/eval.py @@ -54,7 +54,7 @@ if __name__ == '__main__': if not args_opt.do_eval and args_opt.run_distribute: context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) init() epoch_size = config.epoch_size diff --git a/example/resnet101_imagenet2012/train.py b/example/resnet101_imagenet2012/train.py index c2de3e8d98..ca74262890 100755 --- a/example/resnet101_imagenet2012/train.py +++ b/example/resnet101_imagenet2012/train.py @@ -59,7 +59,7 @@ if __name__ == '__main__': if not args_opt.do_eval and args_opt.run_distribute: context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) init() epoch_size = config.epoch_size @@ -91,7 +91,7 @@ if __name__ == '__main__': loss_cb = LossMonitor() cb = [time_cb, loss_cb] if config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size, keep_checkpoint_max=config.keep_checkpoint_max) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck) cb += [ckpt_cb]