|
|
|
@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|
|
|
|
dataset
|
|
|
|
|
"""
|
|
|
|
|
if target == "Ascend":
|
|
|
|
|
device_num = int(os.getenv("DEVICE_NUM"))
|
|
|
|
|
rank_id = int(os.getenv("RANK_ID"))
|
|
|
|
|
device_num, rank_id = _get_rank_info()
|
|
|
|
|
else:
|
|
|
|
|
init("nccl")
|
|
|
|
|
rank_id = get_rank()
|
|
|
|
@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|
|
|
|
dataset
|
|
|
|
|
"""
|
|
|
|
|
if target == "Ascend":
|
|
|
|
|
device_num = int(os.getenv("DEVICE_NUM"))
|
|
|
|
|
rank_id = int(os.getenv("RANK_ID"))
|
|
|
|
|
device_num, rank_id = _get_rank_info()
|
|
|
|
|
else:
|
|
|
|
|
init("nccl")
|
|
|
|
|
rank_id = get_rank()
|
|
|
|
@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|
|
|
|
Returns:
|
|
|
|
|
dataset
|
|
|
|
|
"""
|
|
|
|
|
device_num = int(os.getenv("RANK_SIZE"))
|
|
|
|
|
rank_id = int(os.getenv("RANK_ID"))
|
|
|
|
|
device_num, rank_id = _get_rank_info()
|
|
|
|
|
|
|
|
|
|
if device_num == 1:
|
|
|
|
|
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
|
|
|
|
@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|
|
|
|
ds = ds.repeat(repeat_num)
|
|
|
|
|
|
|
|
|
|
return ds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_rank_info():
|
|
|
|
|
"""
|
|
|
|
|
get rank size and rank id
|
|
|
|
|
"""
|
|
|
|
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
|
|
|
|
|
|
|
|
|
if rank_size > 1:
|
|
|
|
|
rank_size = get_group_size()
|
|
|
|
|
rank_id = get_rank()
|
|
|
|
|
else:
|
|
|
|
|
rank_size = 1
|
|
|
|
|
rank_id = 0
|
|
|
|
|
|
|
|
|
|
return rank_size, rank_id
|
|
|
|
|