|
|
|
@ -106,6 +106,7 @@ if __name__ == '__main__':
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
|
|
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
|
|
|
|
|
|
|
|
|
rank = 0
|
|
|
|
|
if device_target == "Ascend":
|
|
|
|
|
if args_opt.device_id is not None:
|
|
|
|
|
context.set_context(device_id=args_opt.device_id)
|
|
|
|
@ -117,6 +118,7 @@ if __name__ == '__main__':
|
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
gradients_mean=True)
|
|
|
|
|
init()
|
|
|
|
|
rank = get_rank()
|
|
|
|
|
elif device_target == "GPU":
|
|
|
|
|
init()
|
|
|
|
|
|
|
|
|
@ -124,6 +126,7 @@ if __name__ == '__main__':
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
gradients_mean=True)
|
|
|
|
|
rank = get_rank()
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Unsupported platform.")
|
|
|
|
|
|
|
|
|
@ -200,14 +203,13 @@ if __name__ == '__main__':
|
|
|
|
|
if device_target == "Ascend":
|
|
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
|
|
|
|
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
|
|
|
|
|
ckpt_save_dir = "./"
|
|
|
|
|
else: # GPU
|
|
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
|
|
|
|
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
|
|
|
|
|
ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"
|
|
|
|
|
|
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
|
|
|
|
time_cb = TimeMonitor(data_size=batch_num)
|
|
|
|
|
ckpt_save_dir = "./ckpt_" + str(rank) + "/"
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
|
|
|
|
|
config=config_ck)
|
|
|
|
|
loss_cb = LossMonitor()
|
|
|
|
|