# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Functions to support dataset serialization and deserialization.
"""
import json
import os
import sys

from mindspore import log as logger
from . import datasets as de
from ..transforms.vision.utils import Inter, Border


def serialize(dataset, json_filepath=None):
    """
    Serialize the dataset pipeline into a json file.

    Args:
        dataset (Dataset): the starting node.
        json_filepath (string): a filepath where a serialized json file will be generated.

    Returns:
        dict containing the serialized dataset graph.

    Raises:
        OSError: Can not open a file.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import mindspore.dataset.transforms.c_transforms as C
        >>> DATA_DIR = "../../data/testMnistData"
        >>> data = ds.MnistDataset(DATA_DIR, 100)
        >>> one_hot_encode = C.OneHot(10)  # num_classes is input argument
        >>> data = data.map(input_columns="label", operations=one_hot_encode)
        >>> data = data.batch(batch_size=10, drop_remainder=True)
        >>>
        >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")  # serialize it to json file
        >>> serialized_data = ds.engine.serialize(data)  # serialize it to python dict
    """
    serialized_pipeline = traverse(dataset)
    if json_filepath:
        with open(json_filepath, 'w') as json_file:
            json.dump(serialized_pipeline, json_file, indent=2)
    return serialized_pipeline
def deserialize(input_dict=None, json_filepath=None):
    """
    Construct a de pipeline from a json file produced by de.serialize().

    Args:
        input_dict (dict): a python dictionary containing a serialized dataset graph.
        json_filepath (string): a path to the json file.

    Returns:
        de.Dataset or None if error occurs.

    Raises:
        OSError: Can not open a file.

    Examples:
        >>> import mindspore.dataset as ds
        >>> import mindspore.dataset.transforms.c_transforms as C
        >>> DATA_DIR = "../../data/testMnistData"
        >>> data = ds.MnistDataset(DATA_DIR, 100)
        >>> one_hot_encode = C.OneHot(10)  # num_classes is input argument
        >>> data = data.map(input_columns="label", operations=one_hot_encode)
        >>> data = data.batch(batch_size=10, drop_remainder=True)
        >>>
        >>> # Use case 1: to/from json file
        >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")
        >>> data = ds.engine.deserialize(json_filepath="mnist_dataset_pipeline.json")
        >>> # Use case 2: to/from python dictionary
        >>> serialized_data = ds.engine.serialize(data)
        >>> data = ds.engine.deserialize(input_dict=serialized_data)
    """
    data = None
    if input_dict:
        data = construct_pipeline(input_dict)

    if json_filepath:
        dict_pipeline = dict()
        with open(json_filepath, 'r') as json_file:
            dict_pipeline = json.load(json_file)
        data = construct_pipeline(dict_pipeline)

    return data


def expand_path(node_repr, key, val):
    """Convert relative paths to absolute paths."""
    if isinstance(val, list):
        node_repr[key] = [os.path.abspath(file) for file in val]
    else:
        node_repr[key] = os.path.abspath(val)


def serialize_operations(node_repr, key, val):
    """Serialize tensor op (python object) to dictionary."""
    if isinstance(val, list):
        node_repr[key] = []
        for op in val:
            node_repr[key].append(op.__dict__)
            # Extract module and name information from the python object.
            # Example: tensor_op_module is 'minddata.transforms.c_transforms' and tensor_op_name is 'Decode'.
            node_repr[key][-1]['tensor_op_name'] = type(op).__name__
            node_repr[key][-1]['tensor_op_module'] = type(op).__module__
    else:
        node_repr[key] = val.__dict__
        node_repr[key]['tensor_op_name'] = type(val).__name__
        node_repr[key]['tensor_op_module'] = type(val).__module__


def serialize_sampler(node_repr, val):
    """Serialize sampler object to dictionary."""
    if val is None:
        node_repr['sampler'] = None
    else:
        node_repr['sampler'] = val.__dict__
        node_repr['sampler']['sampler_module'] = type(val).__module__
        node_repr['sampler']['sampler_name'] = type(val).__name__


def traverse(node):
    """Pre-order traverse the pipeline and capture the information as we go."""
    # Node representation (node_repr) is a python dictionary that captures and stores the
    # dataset pipeline information before dumping it to JSON or another format.
    node_repr = dict()
    node_repr['op_type'] = type(node).__name__
    node_repr['op_module'] = type(node).__module__

    # Start with an empty list of children; they will be added later as we traverse this node.
    node_repr["children"] = []

    # Retrieve the information about the current node. It should include arguments
    # passed to the node during object construction.
    node_args = node.get_args()
    for k, v in node_args.items():
        # Store the information about this node into node_repr.
        # Further serialize the object in the arguments if needed.
        if k == 'operations':
            serialize_operations(node_repr, k, v)
        elif k == 'sampler':
            serialize_sampler(node_repr, v)
        elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
            expand_path(node_repr, k, v)
        else:
            node_repr[k] = v

    # A leaf node doesn't have the input attribute.
    if not node.input:
        return node_repr

    # Recursively traverse each child and append it to node_repr['children'].
    for child in node.input:
        node_repr["children"].append(traverse(child))

    return node_repr
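# Illustrative sketch (not emitted verbatim by this module) of the dictionary traverse() builds
# for the MnistDataset + batch pipeline from the serialize() example above. Besides 'op_type',
# 'op_module' and 'children', the exact keys depend on each node's get_args():
#
#   {
#       "op_type": "BatchDataset",
#       "op_module": "mindspore.dataset.datasets",
#       "batch_size": 10,
#       "drop_remainder": True,
#       "children": [
#           {
#               "op_type": "MnistDataset",
#               "op_module": "mindspore.dataset.datasets",
#               "dataset_dir": "/abs/path/to/testMnistData",
#               "num_samples": 100,
#               "sampler": None,
#               "children": []
#           }
#       ]
#   }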
def show(dataset, indentation=2):
    """
    Write the dataset pipeline graph onto logger.info.

    Args:
        dataset (Dataset): the starting node.
        indentation (int, optional): indentation used by the json print.
            Pass None to not indent.
    """
    pipeline = traverse(dataset)
    logger.info(json.dumps(pipeline, indent=indentation))


def compare(pipeline1, pipeline2):
    """
    Compare whether two dataset pipelines are the same.

    Args:
        pipeline1 (Dataset): a dataset pipeline.
        pipeline2 (Dataset): a dataset pipeline.

    Returns:
        bool, True if the two pipelines serialize to the same graph.
    """
    return traverse(pipeline1) == traverse(pipeline2)
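# Minimal usage sketch for the two helpers above, assuming `data1` and `data2` are dataset
# pipelines built as in the serialize() example (illustrative only, not executed here):
#
#   show(data1)                       # logs the serialized pipeline graph as indented JSON
#   is_same = compare(data1, data2)   # True when both pipelines serialize to the same graph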
def construct_pipeline(node):
    """Construct the python Dataset objects by following the dictionary deserialized from json file."""
    op_type = node.get('op_type')
    if not op_type:
        raise ValueError("op_type field in the json file can't be None.")

    # Instantiate the python Dataset object based on the current dictionary element.
    dataset = create_node(node)
    # Initially it is not connected to any other object.
    dataset.input = []

    # Construct the children too and add an edge between the children and the parent.
    for child in node['children']:
        dataset.input.append(construct_pipeline(child))

    return dataset


def create_node(node):
    """Parse the key, value in the node dictionary and instantiate the python Dataset object."""
    logger.info('creating node: %s', node['op_type'])
    dataset_op = node['op_type']
    op_module = node['op_module']

    # Get the python class to be instantiated.
    # Example:
    #  "op_type": "MapDataset",
    #  "op_module": "mindspore.dataset.datasets",
    pyclass = getattr(sys.modules[op_module], dataset_op)

    pyobj = None
    # Find a matching Dataset class and call the constructor with the corresponding args.
    # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
    if dataset_op == 'StorageDataset':
        pyobj = pyclass(node['dataset_files'], node['schema'], node.get('distribution'),
                        node.get('columns_list'), node.get('num_parallel_workers'))

    elif dataset_op == 'ImageFolderDatasetV2':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('extensions'), node.get('class_indexing'),
                        node.get('decode'), node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'RangeDataset':
        pyobj = pyclass(node['start'], node['stop'], node['step'])

    elif dataset_op == 'ImageFolderDataset':
        pyobj = pyclass(node['dataset_dir'], node['schema'], node.get('distribution'),
                        node.get('column_list'), node.get('num_parallel_workers'),
                        node.get('deterministic_output'), node.get('prefetch_size'),
                        node.get('labels_filename'), node.get('dataset_usage'))

    elif dataset_op == 'MnistDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'MindDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_file'], node.get('columns_list'), node.get('num_parallel_workers'),
                        node.get('seed'), node.get('num_shards'), node.get('shard_id'),
                        node.get('block_reader'), sampler)

    elif dataset_op == 'TFRecordDataset':
        pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
                        node.get('num_samples'), node.get('num_parallel_workers'),
                        de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'ManifestDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_file'], node['usage'], node.get('num_samples'),
                        node.get('num_parallel_workers'), node.get('shuffle'), sampler,
                        node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
                        node.get('shard_id'))

    elif dataset_op == 'Cifar10Dataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'Cifar100Dataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'VOCDataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                        node.get('shuffle'), node.get('decode'), sampler, node.get('num_shards'),
                        node.get('shard_id'))

    elif dataset_op == 'CelebADataset':
        sampler = construct_sampler(node.get('sampler'))
        pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'),
                        node.get('dataset_type'), sampler, node.get('decode'), node.get('extensions'),
                        node.get('num_samples'), node.get('num_shards'), node.get('shard_id'))

    elif dataset_op == 'GeneratorDataset':
        # Serializing a py function can be done using the marshal library.
        raise RuntimeError(dataset_op + " is not yet supported")

    elif dataset_op == 'RepeatDataset':
        pyobj = de.Dataset().repeat(node.get('count'))

    elif dataset_op == 'SkipDataset':
        pyobj = de.Dataset().skip(node.get('count'))

    elif dataset_op == 'TakeDataset':
        pyobj = de.Dataset().take(node.get('count'))

    elif dataset_op == 'MapDataset':
        tensor_ops = construct_tensor_ops(node.get('operations'))
        pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),
                                 node.get('columns_order'), node.get('num_parallel_workers'))

    elif dataset_op == 'ShuffleDataset':
        pyobj = de.Dataset().shuffle(node.get('buffer_size'))

    elif dataset_op == 'BatchDataset':
        pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))

    elif dataset_op == 'CacheDataset':
        # Member function cache() is not defined in class Dataset yet.
        raise RuntimeError(dataset_op + " is not yet supported")

    elif dataset_op == 'FilterDataset':
        # Member function filter() is not defined in class Dataset yet.
        raise RuntimeError(dataset_op + " is not yet supported")

    elif dataset_op == 'ZipDataset':
        # Create a ZipDataset instance with dummy input datasets that will be overridden in the caller.
        pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))

    elif dataset_op == 'RenameDataset':
        pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])

    elif dataset_op == 'ProjectDataset':
        pyobj = de.Dataset().project(node['columns'])

    elif dataset_op == 'TransferDataset':
        pyobj = de.Dataset().to_device()

    else:
        raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize()")

    return pyobj
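# Illustrative sketch of the dictionary['sampler'] entry that construct_sampler() below consumes.
# It is produced by serialize_sampler() from the sampler's __dict__, so the attribute keys and the
# module path shown here are examples only and depend on the actual sampler object:
#
#   {
#       "sampler_name": "DistributedSampler",
#       "sampler_module": "mindspore.dataset.samplers",
#       "num_shards": 4,
#       "shard_id": 0,
#       "shuffle": True
#   }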
def construct_sampler(in_sampler):
    """Instantiate a Sampler object based on the information from dictionary['sampler']."""
    sampler = None
    if in_sampler is not None:
        sampler_name = in_sampler['sampler_name']
        sampler_module = in_sampler['sampler_module']
        sampler_class = getattr(sys.modules[sampler_module], sampler_name)
        if sampler_name == 'DistributedSampler':
            sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
        elif sampler_name == 'PKSampler':
            sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
        elif sampler_name == 'RandomSampler':
            sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
        elif sampler_name == 'SequentialSampler':
            sampler = sampler_class()
        elif sampler_name == 'SubsetRandomSampler':
            sampler = sampler_class(in_sampler['indices'])
        elif sampler_name == 'WeightedRandomSampler':
            sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
        else:
            raise ValueError("Sampler type is unknown: " + sampler_name)

    return sampler
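# Illustrative sketch of one entry in dictionary['operations'] that construct_tensor_ops() below
# consumes. It is produced by serialize_operations() from the op's __dict__ plus the two bookkeeping
# keys, so the attribute names and module path shown here are examples only:
#
#   {
#       "tensor_op_name": "Normalize",
#       "tensor_op_module": "mindspore.dataset.transforms.vision.c_transforms",
#       "mean": [121.0, 115.0, 100.0],
#       "std": [70.0, 68.0, 71.0]
#   }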
def construct_tensor_ops(operations):
    """Instantiate tensor op object(s) based on the information from dictionary['operations']."""
    result = []
    for op in operations:
        op_module = op['tensor_op_module']
        op_name = op['tensor_op_name']
        op_class = getattr(sys.modules[op_module], op_name)

        if op_name == 'Decode':
            result.append(op_class(op.get('rgb')))

        elif op_name == 'Normalize':
            result.append(op_class(op['mean'], op['std']))

        elif op_name == 'RandomCrop':
            result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'),
                                   op.get('fill_value'), Border(op.get('padding_mode'))))

        elif op_name == 'RandomHorizontalFlip':
            result.append(op_class(op.get('prob')))

        elif op_name == 'RandomVerticalFlip':
            result.append(op_class(op.get('prob')))

        elif op_name == 'Resize':
            result.append(op_class(op['size'], Inter(op.get('interpolation'))))

        elif op_name == 'RandomResizedCrop':
            result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
                                   Inter(op.get('interpolation')), op.get('max_attempts')))

        elif op_name == 'CenterCrop':
            result.append(op_class(op['size']))

        elif op_name == 'RandomColorAdjust':
            result.append(op_class(op.get('brightness'), op.get('contrast'),
                                   op.get('saturation'), op.get('hue')))

        elif op_name == 'RandomRotation':
            result.append(op_class(op['degree'], op.get('resample'), op.get('expand'),
                                   op.get('center'), op.get('fill_value')))

        elif op_name == 'Rescale':
            result.append(op_class(op['rescale'], op['shift']))

        elif op_name == 'RandomResize':
            result.append(op_class(op['size']))

        elif op_name == 'TypeCast':
            result.append(op_class(op['data_type']))

        elif op_name == 'HWC2CHW':
            result.append(op_class())

        elif op_name == 'CHW2HWC':
            raise ValueError("Tensor op is not supported: " + op_name)

        elif op_name == 'OneHot':
            result.append(op_class(op['num_classes']))

        elif op_name == 'RandomCropDecodeResize':
            result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
                                   Inter(op.get('interpolation')), op.get('max_attempts')))

        elif op_name == 'Pad':
            result.append(op_class(op['padding'], op['fill_value'], Border(op['padding_mode'])))

        else:
            raise ValueError("Tensor op name is unknown: " + op_name)

    return result
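# Minimal end-to-end sketch of the public helpers in this module, assuming a pipeline `data`
# built as in the serialize() example (illustrative only, not executed here):
#
#   serialized = serialize(data)                               # pipeline -> python dict
#   serialize(data, json_filepath="pipeline.json")             # pipeline -> json file
#   rebuilt = deserialize(input_dict=serialized)               # dict -> pipeline
#   rebuilt_from_file = deserialize(json_filepath="pipeline.json")
#   assert compare(data, rebuilt)                              # both serialize to the same graph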