|
|
|
@ -127,9 +127,12 @@ def serialize_operations(node_repr, key, val):
|
|
|
|
|
|
|
|
|
|
def serialize_sampler(node_repr, val):
    """Serialize sampler object to dictionary.

    Stores the result under ``node_repr['sampler']``. A ``None`` sampler is
    recorded as ``None``; otherwise the entry is a dictionary holding the
    sampler's instance attributes plus two extra keys, ``'sampler_module'``
    and ``'sampler_name'``, taken from the sampler's type.

    Args:
        node_repr (dict): Node representation, updated in place.
        val: Sampler object to serialize, or None.
    """
    if val is None:
        node_repr['sampler'] = None
        return

    # Work on a copy: writing the bookkeeping keys straight into
    # val.__dict__ would add them as attributes on the live sampler object.
    sampler_repr = dict(val.__dict__)
    sampler_repr['sampler_module'] = type(val).__module__
    sampler_repr['sampler_name'] = type(val).__name__
    node_repr['sampler'] = sampler_repr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def traverse(node):
|
|
|
|
@ -253,9 +256,10 @@ def create_node(node):
|
|
|
|
|
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'MindDataset':
|
|
|
|
|
pyobj = pyclass(node['dataset_file'], node.get('column_list'),
|
|
|
|
|
sampler = construct_sampler(node.get('sampler'))
|
|
|
|
|
pyobj = pyclass(node['dataset_file'], node.get('columns_list'),
|
|
|
|
|
node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'),
|
|
|
|
|
node.get('shard_id'), node.get('block_reader'))
|
|
|
|
|
node.get('shard_id'), node.get('block_reader'), sampler)
|
|
|
|
|
|
|
|
|
|
elif dataset_op == 'TFRecordDataset':
|
|
|
|
|
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
|
|
|
|
@ -341,24 +345,25 @@ def create_node(node):
|
|
|
|
|
|
|
|
|
|
def construct_sampler(in_sampler):
    """Instantiate Sampler object based on the information from dictionary['sampler'].

    Args:
        in_sampler (dict or None): Serialized sampler description. It must
            carry 'sampler_name' and 'sampler_module' plus the keys read by
            the matching branch below. None means no sampler was serialized.

    Returns:
        The constructed sampler object, or None when in_sampler is None.

    Raises:
        ValueError: If 'sampler_name' is not one of the handled sampler types.
    """
    if in_sampler is None:
        return None

    sampler_name = in_sampler['sampler_name']
    sampler_module = in_sampler['sampler_module']
    # The class is resolved from an already-imported module; nothing is imported here.
    sampler_class = getattr(sys.modules[sampler_module], sampler_name)

    if sampler_name == 'DistributedSampler':
        sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
    elif sampler_name == 'PKSampler':
        # Dictionary lookup, not a call: in_sampler('shuffle') raised TypeError.
        sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
    elif sampler_name == 'RandomSampler':
        sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
    elif sampler_name == 'SequentialSampler':
        sampler = sampler_class()
    elif sampler_name == 'SubsetRandomSampler':
        sampler = sampler_class(in_sampler['indices'])
    elif sampler_name == 'WeightedRandomSampler':
        sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
    else:
        raise ValueError("Sampler type is unknown: " + sampler_name)

    return sampler
|
|
|
|
|
|
|
|
|
|