# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Functions to support dataset serialize and deserialize.
"""
import json
import os
import sys

from mindspore import log as logger
from . import datasets as de
from ..transforms.vision.utils import Inter, Border


def serialize(dataset, json_filepath=None):
    """
    Serialize dataset pipeline into a json file.

    Args:
        dataset (Dataset): the starting node.
        json_filepath (str): a filepath where a serialized json file will be generated.

    Returns:
        dict containing the serialized dataset graph.

    Raises:
        OSError: Can not open a file.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import mindspore.dataset.transforms.c_transforms as C
        >>> DATA_DIR = "../../data/testMnistData"
        >>> data = ds.MnistDataset(DATA_DIR, 100)
        >>> one_hot_encode = C.OneHot(10)  # num_classes is input argument
        >>> data = data.map(input_columns="label", operations=one_hot_encode)
        >>> data = data.batch(batch_size=10, drop_remainder=True)
        >>>
        >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")  # serialize it to json file
        >>> serialized_data = ds.engine.serialize(data)  # serialize it to python dict
    """
    serialized_pipeline = traverse(dataset)
    if json_filepath:
        with open(json_filepath, 'w') as json_file:
            json.dump(serialized_pipeline, json_file, indent=2)
    return serialized_pipeline


def deserialize(input_dict=None, json_filepath=None):
    """
    Construct a de pipeline from a json file produced by de.serialize().

    Args:
        input_dict (dict): a python dictionary containing a serialized dataset graph.
        json_filepath (str): a path to the json file.

    Returns:
        de.Dataset or None if error occurs.

    Raises:
        OSError: Can not open a file.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import mindspore.dataset.transforms.c_transforms as C
        >>> DATA_DIR = "../../data/testMnistData"
        >>> data = ds.MnistDataset(DATA_DIR, 100)
        >>> one_hot_encode = C.OneHot(10)  # num_classes is input argument
        >>> data = data.map(input_columns="label", operations=one_hot_encode)
        >>> data = data.batch(batch_size=10, drop_remainder=True)
        >>>
        >>> # Use case 1: to/from json file
        >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")
        >>> data = ds.engine.deserialize(json_filepath="mnist_dataset_pipeline.json")
        >>> # Use case 2: to/from python dictionary
        >>> serialized_data = ds.engine.serialize(data)
        >>> data = ds.engine.deserialize(input_dict=serialized_data)
    """
    data = None
    if input_dict:
        data = construct_pipeline(input_dict)

    if json_filepath:
        dict_pipeline = dict()
        with open(json_filepath, 'r') as json_file:
            dict_pipeline = json.load(json_file)
        data = construct_pipeline(dict_pipeline)

    return data


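# A quick way to sanity-check a round trip (an illustrative sketch only; the file name below
# is just an example): serialize a pipeline, deserialize it, and verify that the two graphs
# match using compare(), which is defined later in this module.
#
#     serialize(data, json_filepath="pipeline.json")
#     restored = deserialize(json_filepath="pipeline.json")
#     assert compare(data, restored)
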
def expand_path(node_repr, key, val):
    """Convert relative to absolute path."""
    if isinstance(val, list):
        node_repr[key] = [os.path.abspath(file) for file in val]
    else:
        node_repr[key] = os.path.abspath(val)


def serialize_operations(node_repr, key, val):
    """Serialize tensor op (python object) to dictionary."""
    if isinstance(val, list):
        node_repr[key] = []
        for op in val:
            node_repr[key].append(op.__dict__)
            # Extracting module and name information from a python object
            # Example: tensor_op_module is 'minddata.transforms.c_transforms' and tensor_op_name is 'Decode'
            node_repr[key][-1]['tensor_op_name'] = type(op).__name__
            node_repr[key][-1]['tensor_op_module'] = type(op).__module__
    else:
        node_repr[key] = val.__dict__
        node_repr[key]['tensor_op_name'] = type(val).__name__
        node_repr[key]['tensor_op_module'] = type(val).__module__


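# For illustration (a sketch, not captured output): a Resize op serialized this way would look
# roughly like the dictionary below. The attribute keys come from the op's __dict__ and the
# module path may differ between versions.
#
#     {"size": 256, "interpolation": 2,
#      "tensor_op_name": "Resize",
#      "tensor_op_module": "mindspore.dataset.transforms.vision.c_transforms"}
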
def serialize_sampler(node_repr, val):
    """Serialize sampler object to dictionary."""
    if val is None:
        node_repr['sampler'] = None
    else:
        node_repr['sampler'] = val.__dict__
        node_repr['sampler']['sampler_module'] = type(val).__module__
        node_repr['sampler']['sampler_name'] = type(val).__name__


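# For illustration (a sketch only; the attribute keys depend on the sampler's __dict__ and the
# module path may differ between versions): a RandomSampler would serialize to something like
#
#     {"replacement": True, "num_samples": 100,
#      "sampler_module": "mindspore.dataset.engine.samplers", "sampler_name": "RandomSampler"}
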
def traverse(node):
    """Pre-order traverse the pipeline and capture the information as we go."""
    # Node representation (node_repr) is a python dictionary that captures and stores the
    # dataset pipeline information before dumping it to JSON or other formats.
    node_repr = dict()
    node_repr['op_type'] = type(node).__name__
    node_repr['op_module'] = type(node).__module__

    # Start with an empty list of children; they will be added later as we traverse this node.
    node_repr["children"] = []

    # Retrieve the information about the current node. It should include arguments
    # passed to the node during object construction.
    node_args = node.get_args()
    for k, v in node_args.items():
        # Store the information about this node into node_repr.
        # Further serialize the object in the arguments if needed.
        if k == 'operations':
            serialize_operations(node_repr, k, v)
        elif k == 'sampler':
            serialize_sampler(node_repr, v)
        # Return schema json str if its type is mindspore.dataset.Schema.
        elif k == 'schema' and isinstance(v, de.Schema):
            node_repr[k] = v.to_json()
        elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
            expand_path(node_repr, k, v)
        else:
            node_repr[k] = v

    # If a sampler exists in this node, then the following 4 arguments must be set to None:
    # num_samples, shard_id, num_shards, shuffle.
    # These arguments get moved into the sampler itself, so they no longer need to
    # be set at the dataset level.
    if 'sampler' in node_args.keys():
        if 'num_samples' in node_repr.keys():
            node_repr['num_samples'] = None
        if 'shuffle' in node_repr.keys():
            node_repr['shuffle'] = None
        if 'num_shards' in node_repr.keys():
            node_repr['num_shards'] = None
        if 'shard_id' in node_repr.keys():
            node_repr['shard_id'] = None

    # Leaf node doesn't have an input attribute.
    if not node.input:
        return node_repr

    # Recursively traverse each child and append it to node_repr['children'].
    for child in node.input:
        node_repr["children"].append(traverse(child))

    return node_repr


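# For illustration only (a sketch; the exact argument keys depend on each node's get_args()):
# a MnistDataset followed by batch() would serialize to a nested structure along these lines,
# with the batch node at the root and its child pipeline under "children":
#
#     {"op_type": "BatchDataset", "op_module": "mindspore.dataset.datasets",
#      "batch_size": 10, "drop_remainder": True,
#      "children": [{"op_type": "MnistDataset", "op_module": "mindspore.dataset.datasets",
#                    "dataset_dir": "/abs/path/testMnistData", "sampler": {...}, "children": []}]}
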
def show(dataset, indentation=2):
    """
    Write the dataset pipeline graph onto logger.info.

    Args:
        dataset (Dataset): the starting node.
        indentation (int, optional): indentation used by the json print. Pass None to not indent.
    """

    pipeline = traverse(dataset)
    logger.info(json.dumps(pipeline, indent=indentation))


def compare(pipeline1, pipeline2):
    """
    Compare whether two dataset pipelines are the same.

    Args:
        pipeline1 (Dataset): a dataset pipeline.
        pipeline2 (Dataset): a dataset pipeline.

    Returns:
        bool, True if the serialized graphs of the two pipelines are equal.
    """

    return traverse(pipeline1) == traverse(pipeline2)


def construct_pipeline(node):
    """Construct the python Dataset objects by following the dictionary deserialized from json file."""
    op_type = node.get('op_type')
    if not op_type:
        raise ValueError("op_type field in the json file can't be None.")

    # Instantiate python Dataset object based on the current dictionary element
    dataset = create_node(node)
    # Initially it is not connected to any other object.
    dataset.input = []

    # Construct the children too and add edge between the children and parent.
    for child in node['children']:
        dataset.input.append(construct_pipeline(child))

    return dataset


def create_node(node):
    """Parse the key, value in the node dictionary and instantiate the python Dataset object"""
    logger.info('creating node: %s', node['op_type'])
    dataset_op = node['op_type']
    op_module = node['op_module']

    # Get the python class to be instantiated.
    # Example:
    #  "op_type": "MapDataset",
    #  "op_module": "mindspore.dataset.datasets",
    pyclass = getattr(sys.modules[op_module], dataset_op)

    pyobj = None
    # Find a matching Dataset class and call the constructor with the corresponding args.
    # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
    if dataset_op == 'ImageFolderDatasetV2':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('extensions'),
                        node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
                        node.get('shard_id'))

    elif dataset_op == 'RangeDataset':
        pyobj = pyclass(node['start'], node['stop'], node['step'])

    elif dataset_op == 'ImageFolderDataset':
        pyobj = pyclass(node['dataset_dir'], node['schema'], node.get('distribution'),
                        node.get('column_list'), node.get('num_parallel_workers'),
                        node.get('deterministic_output'), node.get('prefetch_size'),
                        node.get('labels_filename'), node.get('dataset_usage'))

    elif dataset_op == 'MnistDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'MindDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_file'], node.get('columns_list'),
                        node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'),
                        node.get('shard_id'), node.get('block_reader'), sampler)

    elif dataset_op == 'TFRecordDataset':
        pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
                        node.get('num_samples'), node.get('num_parallel_workers'),
                        de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'ManifestDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_file'], node['usage'], node.get('num_samples'),
                        node.get('num_parallel_workers'), node.get('shuffle'), sampler,
                        node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
                        node.get('shard_id'))

    elif dataset_op == 'Cifar10Dataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'Cifar100Dataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'VOCDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('mode'), node.get('class_indexing'),
                        node.get('num_samples'), node.get('num_parallel_workers'), node.get('shuffle'),
                        node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'CocoDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), node.get('num_samples'),
                        node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'), sampler,
                        node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'CelebADataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'),
                        node.get('dataset_type'), sampler, node.get('decode'), node.get('extensions'),
                        node.get('num_samples'), node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'GeneratorDataset':
        # Serializing py function can be done using marshal library
        raise RuntimeError(dataset_op + " is not yet supported")

    elif dataset_op == 'RepeatDataset':
        pyobj = de.Dataset().repeat(node.get('count'))

    elif dataset_op == 'SkipDataset':
        pyobj = de.Dataset().skip(node.get('count'))

    elif dataset_op == 'TakeDataset':
        pyobj = de.Dataset().take(node.get('count'))

    elif dataset_op == 'MapDataset':
        tensor_ops = construct_tensor_ops(node.get('operations'))
        pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),
                                 node.get('columns_order'), node.get('num_parallel_workers'))

    elif dataset_op == 'ShuffleDataset':
        pyobj = de.Dataset().shuffle(node.get('buffer_size'))

    elif dataset_op == 'BatchDataset':
        pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))

    elif dataset_op == 'CacheDataset':
        # Member function cache() is not defined in class Dataset yet.
        raise RuntimeError(dataset_op + " is not yet supported")

    elif dataset_op == 'FilterDataset':
        # Member function filter() is not defined in class Dataset yet.
        raise RuntimeError(dataset_op + " is not yet supported")

    elif dataset_op == 'ZipDataset':
        # Create ZipDataset instance, giving dummy input dataset that will be overridden in the caller.
        pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))

    elif dataset_op == 'ConcatDataset':
        # Create ConcatDataset instance, giving dummy input dataset that will be overridden in the caller.
        pyobj = de.ConcatDataset((de.Dataset(), de.Dataset()))

    elif dataset_op == 'RenameDataset':
        pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])

    elif dataset_op == 'ProjectDataset':
        pyobj = de.Dataset().project(node['columns'])

    elif dataset_op == 'TransferDataset':
        pyobj = de.Dataset().to_device()

    else:
        raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize()")

    return pyobj


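# Supporting an additional leaf dataset in create_node() follows the same pattern as the
# branches above. A sketch only; 'MyNewDataset' and its argument names are hypothetical and
# not an existing MindSpore API:
#
#     elif dataset_op == 'MyNewDataset':
#         sampler = construct_sampler(node.get('sampler'))
#         pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('shuffle'),
#                         sampler, node.get('num_shards'), node.get('shard_id'))
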
def construct_sampler(in_sampler):
    """Instantiate Sampler object based on the information from dictionary['sampler']"""
    sampler = None
    if in_sampler is not None:
        sampler_name = in_sampler['sampler_name']
        sampler_module = in_sampler['sampler_module']
        sampler_class = getattr(sys.modules[sampler_module], sampler_name)
        if sampler_name == 'DistributedSampler':
            sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
        elif sampler_name == 'PKSampler':
            sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
        elif sampler_name == 'RandomSampler':
            sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
        elif sampler_name == 'SequentialSampler':
            sampler = sampler_class()
        elif sampler_name == 'SubsetRandomSampler':
            sampler = sampler_class(in_sampler['indices'])
        elif sampler_name == 'WeightedRandomSampler':
            sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
        else:
            raise ValueError("Sampler type is unknown: " + sampler_name)

    return sampler


def construct_tensor_ops(operations):
    """Instantiate tensor op object(s) based on the information from dictionary['operations']"""
    result = []
    for op in operations:
        op_module = op['tensor_op_module']
        op_name = op['tensor_op_name']
        op_class = getattr(sys.modules[op_module], op_name)

        if op_name == 'Decode':
            result.append(op_class(op.get('rgb')))

        elif op_name == 'Normalize':
            result.append(op_class(op['mean'], op['std']))

        elif op_name == 'RandomCrop':
            result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'),
                                   op.get('fill_value'), Border(op.get('padding_mode'))))

        elif op_name == 'RandomHorizontalFlip':
            result.append(op_class(op.get('prob')))

        elif op_name == 'RandomVerticalFlip':
            result.append(op_class(op.get('prob')))

        elif op_name == 'Resize':
            result.append(op_class(op['size'], Inter(op.get('interpolation'))))

        elif op_name == 'RandomResizedCrop':
            result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
                                   Inter(op.get('interpolation')), op.get('max_attempts')))

        elif op_name == 'CenterCrop':
            result.append(op_class(op['size']))

        elif op_name == 'RandomColorAdjust':
            result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'),
                                   op.get('hue')))

        elif op_name == 'RandomRotation':
            result.append(op_class(op['degree'], op.get('resample'), op.get('expand'),
                                   op.get('center'), op.get('fill_value')))

        elif op_name == 'Rescale':
            result.append(op_class(op['rescale'], op['shift']))

        elif op_name == 'RandomResize':
            result.append(op_class(op['size']))

        elif op_name == 'TypeCast':
            result.append(op_class(op['data_type']))

        elif op_name == 'HWC2CHW':
            result.append(op_class())

        elif op_name == 'CHW2HWC':
            raise ValueError("Tensor op is not supported: " + op_name)

        elif op_name == 'OneHot':
            result.append(op_class(op['num_classes']))

        elif op_name == 'RandomCropDecodeResize':
            result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
                                   Inter(op.get('interpolation')), op.get('max_attempts')))

        elif op_name == 'Pad':
            result.append(op_class(op['padding'], op['fill_value'], Border(op['padding_mode'])))

        else:
            raise ValueError("Tensor op name is unknown: " + op_name)

    return result