|
|
@ -44,9 +44,9 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|
|
|
device_num = get_group_size()
|
|
|
|
device_num = get_group_size()
|
|
|
|
|
|
|
|
|
|
|
|
if device_num == 1:
|
|
|
|
if device_num == 1:
|
|
|
|
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
|
|
|
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
|
|
|
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True,
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
|
|
|
|
|
|
|
|
# define map operations
|
|
|
|
# define map operations
|
|
|
@ -66,8 +66,8 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|
|
|
|
|
|
|
|
|
|
|
type_cast_op = C2.TypeCast(mstype.int32)
|
|
|
|
type_cast_op = C2.TypeCast(mstype.int32)
|
|
|
|
|
|
|
|
|
|
|
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
|
|
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)
|
|
|
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
|
|
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
|
|
|
|
|
|
|
|
|
|
|
|
# apply batch operations
|
|
|
|
# apply batch operations
|
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
|
|
@ -99,9 +99,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|
|
|
device_num = get_group_size()
|
|
|
|
device_num = get_group_size()
|
|
|
|
|
|
|
|
|
|
|
|
if device_num == 1:
|
|
|
|
if device_num == 1:
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
|
|
|
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
num_shards=device_num, shard_id=rank_id)
|
|
|
|
|
|
|
|
|
|
|
|
image_size = 224
|
|
|
|
image_size = 224
|
|
|
@ -127,8 +127,8 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|
|
|
|
|
|
|
|
|
|
|
type_cast_op = C2.TypeCast(mstype.int32)
|
|
|
|
type_cast_op = C2.TypeCast(mstype.int32)
|
|
|
|
|
|
|
|
|
|
|
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
|
|
|
|
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
|
|
|
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
|
|
|
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12)
|
|
|
|
|
|
|
|
|
|
|
|
# apply batch operations
|
|
|
|
# apply batch operations
|
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True)
|
|
|
|