From 866b53d5a8314fc5f33d9d755d4f90e7f4aba46e Mon Sep 17 00:00:00 2001 From: caifubi Date: Mon, 8 Mar 2021 14:07:39 +0800 Subject: [PATCH] modify pynative gpu benchmark script --- .../cv/resnet/gpu_resnet_benchmark.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py index 90c4305a5d..7ef8b3497f 100644 --- a/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py +++ b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py @@ -94,11 +94,19 @@ class MyTimeMonitor(Callback): def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16", device_num=1): + if args_opt.mode == "GRAPH": + ds_num_parallel_worker = 4 + map_num_parallel_worker = 8 + batch_num_parallel_worker = None + else: + ds_num_parallel_worker = 2 + map_num_parallel_worker = 3 + batch_num_parallel_worker = 2 ds.config.set_numa_enable(True) if device_num == 1: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True) + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True) else: - data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True, + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=ds_num_parallel_worker, shuffle=True, num_shards=device_num, shard_id=get_rank()) image_size = 224 mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] @@ -127,9 +135,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=" ] if dtype == "fp32": trans.append(C.HWC2CHW()) - data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) + data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=map_num_parallel_worker) # apply batch operations - data_set = data_set.batch(batch_size, drop_remainder=True) + data_set = data_set.batch(batch_size, drop_remainder=True, num_parallel_workers=batch_num_parallel_worker) # apply dataset repeat operation if repeat_num > 1: data_set = data_set.repeat(repeat_num) @@ -165,14 +173,16 @@ def train(): # init context if args_opt.mode == "GRAPH": mode = context.GRAPH_MODE + all_reduce_fusion_config = [85, 160] else: mode = context.PYNATIVE_MODE + all_reduce_fusion_config = [30, 90, 160] context.set_context(mode=mode, device_target=dev, save_graphs=False) if args_opt.run_distribute: init() device_num = get_group_size() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True, all_reduce_fusion_config=[85, 160]) + gradients_mean=True, all_reduce_fusion_config=all_reduce_fusion_config) ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/" # create dataset