|
|
|
@ -21,29 +21,38 @@ import mindspore.dataset.vision.c_transforms as C
|
|
|
|
|
from .distributed_sampler import DistributedSampler
|
|
|
|
|
from .datasets import UnalignedDataset, ImageFolderDataset
|
|
|
|
|
|
|
|
|
|
def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
|
|
|
|
|
def create_dataset(args):
|
|
|
|
|
"""Create dataset"""
|
|
|
|
|
dataroot = args.dataroot
|
|
|
|
|
phase = args.phase
|
|
|
|
|
batch_size = args.batch_size
|
|
|
|
|
device_num = args.device_num
|
|
|
|
|
rank = args.rank
|
|
|
|
|
shuffle = args.use_random
|
|
|
|
|
max_dataset_size = args.max_dataset_size
|
|
|
|
|
cores = multiprocessing.cpu_count()
|
|
|
|
|
num_parallel_workers = min(8, int(cores / device_num))
|
|
|
|
|
image_size = args.image_size
|
|
|
|
|
mean = [0.5 * 255] * 3
|
|
|
|
|
std = [0.5 * 255] * 3
|
|
|
|
|
if phase == "train":
|
|
|
|
|
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
|
|
|
|
|
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
|
|
|
|
|
distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
|
|
|
|
|
ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
|
|
|
|
|
sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
|
|
|
|
|
trans = [
|
|
|
|
|
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
|
|
|
|
|
C.RandomHorizontalFlip(prob=0.5),
|
|
|
|
|
C.Normalize(mean=mean, std=std),
|
|
|
|
|
C.HWC2CHW()
|
|
|
|
|
]
|
|
|
|
|
if args.use_random:
|
|
|
|
|
trans = [
|
|
|
|
|
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
|
|
|
|
|
C.RandomHorizontalFlip(prob=0.5),
|
|
|
|
|
C.Normalize(mean=mean, std=std),
|
|
|
|
|
C.HWC2CHW()
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
trans = [
|
|
|
|
|
C.Resize((image_size, image_size)),
|
|
|
|
|
C.Normalize(mean=mean, std=std),
|
|
|
|
|
C.HWC2CHW()
|
|
|
|
|
]
|
|
|
|
|
ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
|
|
|
|
|
ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
|
|
|
|
|
ds = ds.batch(batch_size, drop_remainder=True)
|
|
|
|
|