|
|
|
@ -168,6 +168,17 @@ def create_node(node):
|
|
|
|
|
# Find a matching Dataset class and call the constructor with the corresponding args.
|
|
|
|
|
# When a new Dataset class is introduced, another if clause and its parsing code need to be added.
|
|
|
|
|
# Dataset Source Ops (in alphabetical order)
|
|
|
|
|
pyobj = create_dataset_node(pyclass, node, dataset_op)
|
|
|
|
|
if not pyobj:
|
|
|
|
|
# Dataset Ops (in alphabetical order)
|
|
|
|
|
pyobj = create_dataset_operation_node(node, dataset_op)
|
|
|
|
|
|
|
|
|
|
return pyobj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_dataset_node(pyclass, node, dataset_op):
|
|
|
|
|
"""Parse the key, value in the dataset node dictionary and instantiate the Python Dataset object"""
|
|
|
|
|
pyobj = None
|
|
|
|
|
if dataset_op == 'CelebADataset':
|
|
|
|
|
sampler = construct_sampler(node.get('sampler'))
|
|
|
|
|
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
|
|
|
|
@ -189,7 +200,7 @@ def create_node(node):
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'ClueDataset':
|
|
|
|
|
shuffle = to_shuffle_mode(node.get('shuffle'))
|
|
|
|
|
if shuffle is not None and isinstance(shuffle, str):
|
|
|
|
|
if isinstance(shuffle, str):
|
|
|
|
|
shuffle = de.Shuffle(shuffle)
|
|
|
|
|
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
|
|
|
|
|
pyobj = pyclass(node['dataset_files'], node.get('task'),
|
|
|
|
@ -205,7 +216,7 @@ def create_node(node):
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'CSVDataset':
|
|
|
|
|
shuffle = to_shuffle_mode(node.get('shuffle'))
|
|
|
|
|
if shuffle is not None and isinstance(shuffle, str):
|
|
|
|
|
if isinstance(shuffle, str):
|
|
|
|
|
shuffle = de.Shuffle(shuffle)
|
|
|
|
|
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
|
|
|
|
|
pyobj = pyclass(node['dataset_files'], node.get('field_delim'),
|
|
|
|
@ -237,7 +248,7 @@ def create_node(node):
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'TextFileDataset':
|
|
|
|
|
shuffle = to_shuffle_mode(node.get('shuffle'))
|
|
|
|
|
if shuffle is not None and isinstance(shuffle, str):
|
|
|
|
|
if isinstance(shuffle, str):
|
|
|
|
|
shuffle = de.Shuffle(shuffle)
|
|
|
|
|
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
|
|
|
|
|
pyobj = pyclass(node['dataset_files'], num_samples,
|
|
|
|
@ -246,7 +257,7 @@ def create_node(node):
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'TFRecordDataset':
|
|
|
|
|
shuffle = to_shuffle_mode(node.get('shuffle'))
|
|
|
|
|
if shuffle is not None and isinstance(shuffle, str):
|
|
|
|
|
if isinstance(shuffle, str):
|
|
|
|
|
shuffle = de.Shuffle(shuffle)
|
|
|
|
|
num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
|
|
|
|
|
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'),
|
|
|
|
@ -260,8 +271,13 @@ def create_node(node):
|
|
|
|
|
num_samples, node.get('num_parallel_workers'), node.get('shuffle'),
|
|
|
|
|
node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))
|
|
|
|
|
|
|
|
|
|
# Dataset Ops (in alphabetical order)
|
|
|
|
|
elif dataset_op == 'Batch':
|
|
|
|
|
return pyobj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_dataset_operation_node(node, dataset_op):
|
|
|
|
|
"""Parse the key, value in the dataset operation node dictionary and instantiate the Python Dataset object"""
|
|
|
|
|
pyobj = None
|
|
|
|
|
if dataset_op == 'Batch':
|
|
|
|
|
pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'Map':
|
|
|
|
@ -292,7 +308,7 @@ def create_node(node):
|
|
|
|
|
pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue'))
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'Zip':
|
|
|
|
|
# Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
|
|
|
|
|
# Create ZipDataset instance, giving dummy input datasets that will be overridden in the caller.
|
|
|
|
|
pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|