From 4bbc3445f129e9358b574f4b7157ae657245cac2 Mon Sep 17 00:00:00 2001 From: hesham Date: Thu, 4 Feb 2021 19:12:52 -0500 Subject: [PATCH] datasets.py cleanup --- .../dataset/include/datasets_bindings.cc | 154 +- .../engine/ir/datasetops/dataset_node.cc | 5 + .../engine/ir/datasetops/dataset_node.h | 5 + mindspore/dataset/core/config.py | 2 +- mindspore/dataset/engine/datasets.py | 1414 ++++------------- mindspore/dataset/engine/samplers.py | 2 +- mindspore/dataset/transforms/py_transforms.py | 29 + 7 files changed, 381 insertions(+), 1230 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc index 8a94ab4a5c..eb8ba07c09 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/datasets_bindings.cc @@ -82,10 +82,15 @@ namespace dataset { PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) { (void)py::class_>(*m, "Dataset") - .def("SetNumWorkers", + .def("set_num_workers", [](std::shared_ptr self, std::optional num_workers) { return num_workers ? self->SetNumWorkers(*num_workers) : self; }) + .def("set_cache_client", + [](std::shared_ptr self) { + std::shared_ptr dc = nullptr; + return self->SetDatasetCache(dc); + }) .def( "Zip", [](std::shared_ptr self, py::list datasets) { @@ -109,10 +114,9 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) { (void)py::class_>(*m, "CelebANode", "to create a CelebANode") .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode, - py::list extensions, std::shared_ptr cc) { - auto celebA = - std::make_shared(dataset_dir, usage, toSamplerObj(sampler), decode, - toStringSet(extensions), toDatasetCache(std::move(cc))); + py::list extensions) { + auto celebA = std::make_shared(dataset_dir, usage, toSamplerObj(sampler), decode, + toStringSet(extensions), nullptr); THROW_IF_ERROR(celebA->ValidateParams()); return celebA; })); @@ -121,10 +125,8 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) { PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) { (void)py::class_>(*m, "Cifar10Node", "to create a Cifar10Node") - .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, - std::shared_ptr cc) { - auto cifar10 = std::make_shared(dataset_dir, usage, toSamplerObj(sampler), - toDatasetCache(std::move(cc))); + .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { + auto cifar10 = std::make_shared(dataset_dir, usage, toSamplerObj(sampler), nullptr); THROW_IF_ERROR(cifar10->ValidateParams()); return cifar10; })); @@ -133,36 +135,34 @@ PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) { PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) { (void)py::class_>(*m, "Cifar100Node", "to create a Cifar100Node") - .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, - std::shared_ptr cc) { - auto cifar100 = std::make_shared(dataset_dir, usage, toSamplerObj(sampler), - toDatasetCache(std::move(cc))); + .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { + auto cifar100 = + std::make_shared(dataset_dir, usage, toSamplerObj(sampler), nullptr); THROW_IF_ERROR(cifar100->ValidateParams()); return cifar100; })); })); -PYBIND_REGISTER( - CLUENode, 2, ([](const py::module *m) { - (void)py::class_>(*m, "CLUENode", "to create a CLUENode") - 
.def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle, - int32_t num_shards, int32_t shard_id, std::shared_ptr cc) { - std::shared_ptr clue_node = - std::make_shared(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle), - num_shards, shard_id, toDatasetCache(std::move(cc))); - THROW_IF_ERROR(clue_node->ValidateParams()); - return clue_node; - })); - })); +PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) { + (void)py::class_>(*m, "CLUENode", + "to create a CLUENode") + .def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, + int32_t shuffle, int32_t num_shards, int32_t shard_id) { + std::shared_ptr clue_node = + std::make_shared(toStringVector(files), task, usage, num_samples, + toShuffleMode(shuffle), num_shards, shard_id, nullptr); + THROW_IF_ERROR(clue_node->ValidateParams()); + return clue_node; + })); + })); PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "CocoNode", "to create a CocoNode") .def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, - bool decode, py::handle sampler, std::shared_ptr cc) { - std::shared_ptr coco = - std::make_shared(dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), - toDatasetCache(std::move(cc))); + bool decode, py::handle sampler) { + std::shared_ptr coco = std::make_shared( + dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr); THROW_IF_ERROR(coco->ValidateParams()); return coco; })); @@ -172,10 +172,10 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "CSVNode", "to create a CSVNode") .def(py::init([](std::vector csv_files, char field_delim, py::list column_defaults, std::vector column_names, int64_t num_samples, int32_t shuffle, - int32_t num_shards, int32_t shard_id, std::shared_ptr cc) { - auto csv = std::make_shared(csv_files, field_delim, toCSVBase(column_defaults), - column_names, num_samples, toShuffleMode(shuffle), - num_shards, shard_id, toDatasetCache(std::move(cc))); + int32_t num_shards, int32_t shard_id) { + auto csv = + std::make_shared(csv_files, field_delim, toCSVBase(column_defaults), column_names, + num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); THROW_IF_ERROR(csv->ValidateParams()); return csv; })); @@ -205,12 +205,12 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { (void)py::class_>( *m, "ImageFolderNode", "to create an ImageFolderNode") .def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions, - py::dict class_indexing, std::shared_ptr cc) { + py::dict class_indexing) { // Don't update recursive to true bool recursive = false; // Will be removed in future PR - auto imagefolder = std::make_shared( - dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions), - toStringMap(class_indexing), toDatasetCache(std::move(cc))); + auto imagefolder = std::make_shared(dataset_dir, decode, toSamplerObj(sampler), + recursive, toStringSet(extensions), + toStringMap(class_indexing), nullptr); THROW_IF_ERROR(imagefolder->ValidateParams()); return imagefolder; })); @@ -220,10 +220,9 @@ PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "ManifestNode", "to create a ManifestNode") .def(py::init([](std::string dataset_file, std::string usage, py::handle sampler, - py::dict class_indexing, bool decode, std::shared_ptr cc) { + py::dict class_indexing, bool decode) { auto manifest = 
std::make_shared(dataset_file, usage, toSamplerObj(sampler), - toStringMap(class_indexing), decode, - toDatasetCache(std::move(cc))); + toStringMap(class_indexing), decode, nullptr); THROW_IF_ERROR(manifest->ValidateParams()); return manifest; })); @@ -261,41 +260,38 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) { PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "MnistNode", "to create an MnistNode") - .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, - std::shared_ptr cc) { - auto mnist = std::make_shared(dataset_dir, usage, toSamplerObj(sampler), - toDatasetCache(std::move(cc))); + .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) { + auto mnist = std::make_shared(dataset_dir, usage, toSamplerObj(sampler), nullptr); THROW_IF_ERROR(mnist->ValidateParams()); return mnist; })); })); -PYBIND_REGISTER( - RandomNode, 2, ([](const py::module *m) { - (void)py::class_>(*m, "RandomNode", "to create a RandomNode") - .def(py::init([](int32_t total_rows, std::shared_ptr schema, py::list columns_list, - std::shared_ptr cc) { - auto random_node = - std::make_shared(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); - THROW_IF_ERROR(random_node->ValidateParams()); - return random_node; - })) - .def(py::init([](int32_t total_rows, std::string schema, py::list columns_list, std::shared_ptr cc) { - auto random_node = - std::make_shared(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); - THROW_IF_ERROR(random_node->ValidateParams()); - return random_node; - })); - })); +PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) { + (void)py::class_>(*m, "RandomNode", + "to create a RandomNode") + .def(py::init([](int32_t total_rows, std::shared_ptr schema, py::list columns_list) { + auto random_node = + std::make_shared(total_rows, schema, toStringVector(columns_list), nullptr); + THROW_IF_ERROR(random_node->ValidateParams()); + return random_node; + })) + .def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) { + auto random_node = + std::make_shared(total_rows, schema, toStringVector(columns_list), nullptr); + THROW_IF_ERROR(random_node->ValidateParams()); + return random_node; + })); + })); PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "TextFileNode", "to create a TextFileNode") .def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards, - int32_t shard_id, std::shared_ptr cc) { - std::shared_ptr textfile_node = std::make_shared( - toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id, - toDatasetCache(std::move(cc))); + int32_t shard_id) { + std::shared_ptr textfile_node = + std::make_shared(toStringVector(dataset_files), num_samples, + toShuffleMode(shuffle), num_shards, shard_id, nullptr); THROW_IF_ERROR(textfile_node->ValidateParams()); return textfile_node; })); @@ -306,19 +302,19 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) { "to create a TFRecordNode") .def(py::init([](py::list dataset_files, std::shared_ptr schema, py::list columns_list, int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id, - bool shard_equal_rows, std::shared_ptr cc) { + bool shard_equal_rows) { std::shared_ptr tfrecord = std::make_shared( toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples, - toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, 
toDatasetCache(std::move(cc))); + toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr); THROW_IF_ERROR(tfrecord->ValidateParams()); return tfrecord; })) .def(py::init([](py::list dataset_files, std::string schema, py::list columns_list, int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id, - bool shard_equal_rows, std::shared_ptr cc) { + bool shard_equal_rows) { std::shared_ptr tfrecord = std::make_shared( toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples, - toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc))); + toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr); THROW_IF_ERROR(tfrecord->ValidateParams()); return tfrecord; })); @@ -326,15 +322,13 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) { PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "VOCNode", "to create a VOCNode") - .def( - py::init([](std::string dataset_dir, std::string task, std::string usage, py::dict class_indexing, - bool decode, py::handle sampler, std::shared_ptr cc) { - std::shared_ptr voc = - std::make_shared(dataset_dir, task, usage, toStringMap(class_indexing), decode, - toSamplerObj(sampler), toDatasetCache(std::move(cc))); - THROW_IF_ERROR(voc->ValidateParams()); - return voc; - })); + .def(py::init([](std::string dataset_dir, std::string task, std::string usage, + py::dict class_indexing, bool decode, py::handle sampler) { + std::shared_ptr voc = std::make_shared( + dataset_dir, task, usage, toStringMap(class_indexing), decode, toSamplerObj(sampler), nullptr); + THROW_IF_ERROR(voc->ValidateParams()); + return voc; + })); })); // PYBIND FOR NON-LEAF NODES @@ -439,11 +433,11 @@ PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) { PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "MapNode", "to create a MapNode") .def(py::init([](std::shared_ptr self, py::list operations, py::list input_columns, - py::list output_columns, py::list project_columns, std::shared_ptr cc, + py::list output_columns, py::list project_columns, std::vector> py_callbacks) { auto map = std::make_shared( self, std::move(toTensorOperations(operations)), toStringVector(input_columns), - toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)), + toStringVector(output_columns), toStringVector(project_columns), nullptr, std::vector>(py_callbacks.begin(), py_callbacks.end())); THROW_IF_ERROR(map->ValidateParams()); return map; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc index 52086c0077..a3cba83fd6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc @@ -212,6 +212,11 @@ std::shared_ptr DatasetNode::SetNumWorkers(int32_t num_workers) { return shared_from_this(); } +std::shared_ptr DatasetNode::SetDatasetCache(const std::shared_ptr &cache) { + cache_ = cache; + return shared_from_this(); +} + DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h index 3f9e47a8f5..e95251637b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h @@ 
-260,6 +260,11 @@ class DatasetNode : public std::enable_shared_from_this { /// \return Shared pointer to the original object std::shared_ptr SetNumWorkers(int32_t num_workers); + /// \brief Setter function for DatasetCache + /// \param[in] cache Shared pointer to DatasetCache + /// \return Shared pointer to the original object + std::shared_ptr SetDatasetCache(const std::shared_ptr &cache); + /// \brief A helper templated function for casting "this" pointer to shared_ptr /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr /// \return A shared_ptr casted to the derived class diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py index 93749b5013..494ca51949 100644 --- a/mindspore/dataset/core/config.py +++ b/mindspore/dataset/core/config.py @@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers', 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load', - 'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers'] + 'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers', '_init_device_info'] INT32_MAX = 2147483647 UINT32_MAX = 4294967295 diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 72dfd50ba1..4273727466 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -131,7 +131,7 @@ class Dataset: (default=None). """ - def __init__(self, children=None, num_parallel_workers=None): + def __init__(self, children=None, num_parallel_workers=None, cache=None): # Note: children and parent are internal variables, not recommended for external using. self.children = replace_none(children, []) if isinstance(self.children, tuple): @@ -143,6 +143,7 @@ class Dataset: for child in self.children: child.parent.append(weakref.ref(self)) self.num_parallel_workers = num_parallel_workers + self.cache = cache # todo check the following: self._device_iter = 0 @@ -223,8 +224,25 @@ class Dataset: # Bootstrap on original dataset node will make all iterators share the same process pool self.iterator_bootstrap() ir_node = self.parse(ir_children) + ir_node = self.post_parse(ir_node) return ir_node + def __safe_deepcopy__(self, memodict, exclude=()): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_op = cls.__new__(cls) + memodict[id(self)] = new_op + for arg, value in self.__dict__.items(): + if arg in exclude: + setattr(new_op, arg, value) + else: + try: + setattr(new_op, arg, copy.deepcopy(value, memodict)) + except TypeError: + setattr(new_op, arg, value) + return new_op + def iterator_bootstrap(self): pass @@ -237,21 +255,6 @@ class Dataset: def __add__(self, datasets): return self.concat(datasets) - def get_args(self): - """ - Return attributes (member variables) related to the current class. - - Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'. - - Args: - - Returns: - dict, attributes related to the current class. - """ - args = dict() - args["num_parallel_workers"] = self.num_parallel_workers - return args - def to_json(self, filename=""): """ Serialize a pipeline into JSON string and dump into file if filename is provided. 
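The new Dataset.__safe_deepcopy__ helper added in the hunk above is what lets the per-class, hand-written __deepcopy__ methods later in this patch collapse to one-liners: attributes named in exclude (and anything that raises TypeError when copied) are shared by reference instead of being deep-copied. The following is a minimal self-contained sketch of that pattern only; _Node and its worker_fn attribute are illustrative stand-ins, not MindSpore classes.

import copy

class _Node:
    # Illustrative stand-in for a dataset node, not MindSpore code.
    def __init__(self, data, worker_fn=None):
        self.data = data
        self.worker_fn = worker_fn  # a user callable we prefer to share, not copy

    def __safe_deepcopy__(self, memodict, exclude=()):
        if id(self) in memodict:
            return memodict[id(self)]
        cls = self.__class__
        new_op = cls.__new__(cls)
        memodict[id(self)] = new_op
        for arg, value in self.__dict__.items():
            if arg in exclude:
                setattr(new_op, arg, value)              # share by reference
            else:
                try:
                    setattr(new_op, arg, copy.deepcopy(value, memodict))
                except TypeError:
                    setattr(new_op, arg, value)          # uncopyable -> share
        return new_op

    def __deepcopy__(self, memodict):
        return self.__safe_deepcopy__(memodict, exclude=("worker_fn",))

node = _Node([1, 2, 3], worker_fn=lambda x: x * 2)
clone = copy.deepcopy(node)
assert clone.data == node.data and clone.data is not node.data
assert clone.worker_fn is node.worker_fn   # excluded attribute is shared, not copied
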
@@ -266,9 +269,8 @@ class Dataset: return json.loads(ir_tree.to_json(filename)) @check_bucket_batch_by_length - def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, - element_length_function=None, pad_info=None, - pad_to_bucket_boundary=False, drop_remainder=False): + def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, + pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False): """ Bucket elements according to their lengths. Each bucket will be padded and batched when they are full. @@ -335,8 +337,7 @@ class Dataset: ... pad_to_bucket_boundary) """ return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes, - element_length_function, pad_info, - pad_to_bucket_boundary, drop_remainder) + element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder) @check_batch def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, @@ -984,7 +985,7 @@ class Dataset: elif isinstance(datasets, list): datasets = [self] + datasets else: - raise TypeError("Invalid datasets, expected Dataset object or list of Dataset, but got %s!" % (datasets)) + raise TypeError("Invalid datasets, expected Dataset object or list of Dataset, but got %s!" % datasets) return ConcatDataset(datasets) @check_rename @@ -1620,17 +1621,46 @@ class Dataset: def parse(self, children=None): raise NotImplementedError("Dataset has to implement parse method.") + def post_parse(self, ir_node): + if self.cache: + ir_node = ir_node.set_cache_client(self.cache.cache_client) + if self.num_parallel_workers: + ir_node = ir_node.set_num_workers(self.num_parallel_workers) + + return ir_node + class SourceDataset(Dataset): """ Abstract class to represent a source dataset which produces content to the data pipeline. 
""" + def __init__(self, num_parallel_workers=None, num_samples=None, shuffle=True, num_shards=None, shard_id=None, + cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, cache=cache) + self.num_samples = replace_none(num_samples, 0) + self.num_shards = replace_none(num_shards, 1) + self.shard_id = replace_none(shard_id, 0) + + if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)): + raise TypeError( + "shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or 'Shuffle.FILES'.") + + self.shuffle_flag = 2 # Global shuffle + if not isinstance(shuffle, Shuffle): + if shuffle is None or shuffle: + self.shuffle_flag = 2 # Global shuffle + else: + self.shuffle_flag = 0 # No shuffle + else: + if shuffle == Shuffle.GLOBAL: + self.shuffle_flag = 2 # Global shuffle + elif shuffle == Shuffle.FILES: + self.shuffle_flag = 1 # Files shuffle + def parse(self, children=None): raise NotImplementedError("Dataset has to implement parse method.") - # No need for __init__ since it is the same as the super's init - @staticmethod def _find_files(patterns): """ @@ -1664,10 +1694,12 @@ class SourceDataset(Dataset): raise ValueError("The list of path names matching the patterns is empty.") def is_shuffled(self): - raise NotImplementedError("SourceDataset must implement is_shuffled.") + return self.shuffle_flag > 0 def is_sharded(self): - raise NotImplementedError("SourceDataset must implement is_sharded.") + if self.num_shards is not None: + return self.num_shards > 1 + return False class MappableDataset(SourceDataset): @@ -1678,10 +1710,12 @@ class MappableDataset(SourceDataset): def parse(self, children=None): raise NotImplementedError("Dataset has to implement parse method.") - def __init__(self, num_parallel_workers=None): - # check if all subclasses use this name - super().__init__(num_parallel_workers=num_parallel_workers) - self.sampler = None + def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None, + shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, + num_shards=num_shards, shard_id=shard_id, cache=cache) + self.shuffle_flag = replace_none(shuffle, True) + self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) def add_sampler(self, new_sampler): # note: By adding a sampler, the sampled IDs will flow to new_sampler @@ -1715,19 +1749,10 @@ class MappableDataset(SourceDataset): self.add_sampler(new_sampler) def is_shuffled(self): - raise NotImplementedError("MappableDataset must implement is_shuffled.") + return self.sampler.is_shuffled() def is_sharded(self): - raise NotImplementedError("MappableDataset must implement is_sharded.") - - def _get_sampler_dataset_size(self): - if self.sampler is not None: - if hasattr(self.sampler, 'get_num_samples'): - return self.sampler.get_num_samples() - if hasattr(self.sampler, '__len__'): - return len(self.sampler) - - return None + return self.sampler.is_sharded() @check_split def split(self, sizes, randomize=True): @@ -1831,11 +1856,11 @@ class BucketBatchByLengthDataset(Dataset): The result of applying BucketBatchByLength operator to the input dataset. 
""" - def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, - element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder): + def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, + pad_info, pad_to_bucket_boundary, drop_remainder): super().__init__(children=input_dataset) - self.column_names = replace_none(column_names, []) + self.column_names = to_list(column_names) self.bucket_boundaries = replace_none(bucket_boundaries, []) self.bucket_batch_sizes = replace_none(bucket_batch_sizes, []) self.element_length_function = element_length_function @@ -1848,17 +1873,6 @@ class BucketBatchByLengthDataset(Dataset): self.bucket_batch_sizes, self.element_length_function, self.pad_info, self.pad_to_bucket_boundary, self.drop_remainder) - def get_args(self): - args = super().get_args() - args["length_dependent_columns"] = self.column_names - args["bucket_boundaries"] = self.bucket_boundaries - args["bucket_batch_sizes"] = self.bucket_batch_sizes - args["element_length_function"] = self.element_length_function - args["pad_info"] = self.pad_info - args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary - args["drop_remainder"] = self.drop_remainder - return args - class BatchDataset(Dataset): """ @@ -1896,8 +1910,8 @@ class BatchDataset(Dataset): """ - def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, - per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None, + def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, + input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False): super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) @@ -1906,18 +1920,18 @@ class BatchDataset(Dataset): BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) - # replace non on input args - input_columns = replace_none(input_columns, []) - output_columns = replace_none(output_columns, []) - column_order = replace_none(column_order, []) # if batch_size is callable, set batch_size to 1 and batch_size_func to that callable function self.batch_size = batch_size if not callable(batch_size) else 1 self.batch_size_func = None if not callable(batch_size) else batch_size + self.drop_remainder = replace_none(drop_remainder, False) + self.per_batch_map = per_batch_map - self.input_columns = input_columns if not isinstance(input_columns, str) else [input_columns] - self.output_columns = output_columns if not isinstance(output_columns, str) else [output_columns] - self.column_order = column_order if not isinstance(column_order, str) else [column_order] + + self.input_columns = to_list(input_columns) + self.output_columns = to_list(output_columns) + self.column_order = to_list(column_order) + self.pad = bool(pad_info is not None) self.pad_info = replace_none(pad_info, dict()) @@ -1926,21 +1940,9 @@ class BatchDataset(Dataset): self.hook = None def parse(self, children=None): - return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, - self.input_columns, self.output_columns, - self.column_order, self.batch_size_func, self.per_batch_map, - self.pad_info).SetNumWorkers(self.num_parallel_workers) - - def get_args(self): - args = super().get_args() - args["batch_size"] = self.batch_size - args["drop_remainder"] = self.drop_remainder - args["per_batch_map"] = self.per_batch_map - 
args["input_columns"] = self.input_columns - args["output_columns"] = self.output_columns - args["column_order"] = self.column_order - args["pad_info"] = self.pad_info - return args + return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, self.input_columns, + self.output_columns, self.column_order, self.batch_size_func, self.per_batch_map, + self.pad_info) @staticmethod def _is_ancestor_of_repeat(dataset): @@ -1975,33 +1977,7 @@ class BatchDataset(Dataset): BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) def __deepcopy__(self, memodict): - if id(self) in memodict: - return memodict[id(self)] - cls = self.__class__ - new_op = cls.__new__(cls) - memodict[id(self)] = new_op - new_op.children = copy.deepcopy(self.children, memodict) - new_op.parent = copy.deepcopy(self.parent, memodict) - new_op.num_parallel_workers = self.num_parallel_workers - new_op.batch_size = self.batch_size - new_op.batch_size_func = self.batch_size_func - new_op.drop_remainder = self.drop_remainder - new_op.per_batch_map = self.per_batch_map - new_op.input_columns = copy.deepcopy(self.input_columns, memodict) - new_op.output_columns = copy.deepcopy(self.output_columns, memodict) - new_op.column_order = copy.deepcopy(self.column_order, memodict) - new_op.saved_output_types = self.saved_output_types - new_op.saved_output_shapes = self.saved_output_shapes - new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) - new_op.copy_batch_size(copy.deepcopy(self._batch_size, memodict)) - new_op.dataset_size = self.dataset_size - new_op.pad = self.pad - new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) - new_op.hook = copy.deepcopy(self.hook, memodict) - new_op.pad_info = copy.deepcopy(self.pad_info, memodict) - if hasattr(self, "__total_batch__"): - new_op.__total_batch__ = self.__total_batch__ - return new_op + return self.__safe_deepcopy__(memodict, exclude=("per_batch_map", "batch_size_func", "__transfer_dataset__")) # Iterator bootstrap will be called on iterator construction. # A deep copy of Dataset object is created prior of iterator_bootstrap. 
@@ -2014,8 +1990,7 @@ class BatchDataset(Dataset): # Construct pool with the callable list # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, - initializer=_pyfunc_worker_init, - initargs=([self.per_batch_map],)) + initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],)) idx = 0 # Wrap per_batch_map into _PythonCallable self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) @@ -2065,11 +2040,6 @@ class BlockReleasePair: self.disable = False def __deepcopy__(self, memodict): - if id(self) in memodict: - return memodict[id(self)] - memodict[id(self)] = self - # condition variable and callback are the same, but reset the counter - self.reset() return self def reset(self): @@ -2158,12 +2128,6 @@ class SyncWaitDataset(Dataset): def is_sync(self): return True - def get_args(self): - args = super().get_args() - args["condition_name"] = self._condition_name - args["condition_func"] = self._pair.block_func - return args - def update_sync_batch_size(self, batch_size): if isinstance(batch_size, int) and batch_size <= 0: raise ValueError("num_batch need to be greater than 0.") @@ -2191,6 +2155,9 @@ class SyncWaitDataset(Dataset): flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) return flag + def iterator_bootstrap(self): + self._pair.reset() + class ShuffleDataset(Dataset): """ @@ -2215,14 +2182,6 @@ class ShuffleDataset(Dataset): def parse(self, children=None): return cde.ShuffleNode(children[0], self.buffer_size, self.reshuffle_each_epoch) - def get_args(self): - args = super().get_args() - args["buffer_size"] = self.buffer_size - if self.reshuffle_each_epoch is not None: - args["reshuffle_each_epoch"] = self.reshuffle_each_epoch - - return args - def is_shuffled(self): return True @@ -2302,8 +2261,8 @@ class _ExceptHookHandler: def __init__(self): sys.excepthook = self.__handler_exception - def __handler_exception(self, type, value, tb): - logger.error("Uncaught exception: ", exc_info=(type, value, tb)) + def __handler_exception(self, ex_type, value, tb): + logger.error("Uncaught exception: ", exc_info=(ex_type, value, tb)) _mp_pool_exit_preprocess() @@ -2338,53 +2297,30 @@ class MapDataset(Dataset): def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None): - super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) - if operations is not None: - if not isinstance(operations, list): - operations = [operations] - elif isinstance(operations, list) and len(operations) > 1: - # wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations - new_ops, start_ind, end_ind = [], 0, 0 - for i, op in enumerate(operations): - if str(op).find("c_transform") >= 0: - # reset counts - if start_ind != end_ind: - new_ops.append(py_transforms.Compose(operations[start_ind:end_ind])) - new_ops.append(op) - start_ind, end_ind = i + 1, i + 1 - else: - end_ind += 1 - # do additional check in case the last operation is a Python operation - if start_ind != end_ind: - new_ops.append(py_transforms.Compose(operations[start_ind:end_ind])) - operations = new_ops - self.operations = replace_none(operations, []) - if input_columns is not None and not isinstance(input_columns, list): - input_columns = [input_columns] - self.input_columns = replace_none(input_columns, 
[]) - if output_columns is not None and not isinstance(output_columns, list): - output_columns = [output_columns] - self.output_columns = replace_none(output_columns, self.input_columns) - self.cache = cache - self.column_order = column_order + super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache) + self.operations = to_list(operations) + self.operations = py_transforms.Compose.reduce(self.operations) + self.input_columns = to_list(input_columns) + self.output_columns = to_list(output_columns) + self.column_order = replace_none(column_order, []) + # If output_columns were not provided then use input_columns + self.output_columns = self.input_columns if not self.output_columns else self.output_columns + + # todo(crc): move to @check_map if self.input_columns and self.output_columns \ and len(self.input_columns) != len(self.output_columns) \ - and self.column_order is None: + and not self.column_order: raise ValueError("When length of input_columns and output_columns are not equal," " column_order must be specified.") self.python_multiprocessing = python_multiprocessing self.process_pool = None - - if callbacks is not None and not isinstance(callbacks, list): - callbacks = [callbacks] - - self.callbacks = callbacks self.hook = None + self.callbacks = to_list(callbacks) + def parse(self, children=None): - column_order = replace_none(self.column_order, []) operations = [] for op in self.operations: if op and getattr(op, 'parse', None): @@ -2392,48 +2328,12 @@ class MapDataset(Dataset): else: operations.append(op) - cc = self.cache.cache_client if self.cache else None - callbacks = [cb.create_runtime_obj() for cb in self.callbacks] if self.callbacks else [] - return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, column_order, cc, - callbacks).SetNumWorkers(self.num_parallel_workers) - - def get_args(self): - args = super().get_args() - args["input_columns"] = self.input_columns - args["operations"] = self.operations - args["output_columns"] = self.output_columns - args["column_order"] = self.column_order - args["cache"] = self.cache.cache_client if self.cache is not None else None - - if self.callbacks is not None: - args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks] - return args + callbacks = [cb.create_runtime_obj() for cb in self.callbacks] + return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order, + callbacks) def __deepcopy__(self, memodict): - if id(self) in memodict: - return memodict[id(self)] - cls = self.__class__ - new_op = cls.__new__(cls) - memodict[id(self)] = new_op - new_op.children = copy.deepcopy(self.children, memodict) - new_op.input_columns = copy.deepcopy(self.input_columns, memodict) - new_op.output_columns = copy.deepcopy(self.output_columns, memodict) - new_op.column_order = copy.deepcopy(self.column_order, memodict) - new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) - new_op.parent = copy.deepcopy(self.parent, memodict) - new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) - new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) - new_op.cache = copy.deepcopy(self.cache, memodict) - new_op.hook = copy.deepcopy(self.hook, memodict) - new_op.operations = self.operations - new_op.dataset_size = self.dataset_size - new_op.saved_output_types = self.saved_output_types - new_op.saved_output_shapes = self.saved_output_shapes - - new_op.callbacks = self.callbacks 
- if hasattr(self, "__total_batch__"): - new_op.__total_batch__ = self.__total_batch__ - return new_op + return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__")) # Iterator bootstrap will be called on iterator construction. # A deep copy of Dataset object is created prior of iterator_bootstrap. @@ -2456,8 +2356,7 @@ class MapDataset(Dataset): # Construct pool with the callable list # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, - initializer=_pyfunc_worker_init, - initargs=(callable_list,)) + initializer=_pyfunc_worker_init, initargs=(callable_list,)) # Pass #2 idx = 0 for op in self.operations: @@ -2496,20 +2395,11 @@ class FilterDataset(Dataset): def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers) self.predicate = lambda *args: bool(predicate(*args)) - - if input_columns is not None and not isinstance(input_columns, list): - input_columns = [input_columns] - self.input_columns = replace_none(input_columns, []) + self.input_columns = to_list(input_columns) def parse(self, children=None): return cde.FilterNode(children[0], self.predicate, self.input_columns) - def get_args(self): - args = super().get_args() - args["predicate"] = self.predicate - args["input_columns"] = self.input_columns - return args - class RepeatDataset(Dataset): """ @@ -2527,11 +2417,6 @@ class RepeatDataset(Dataset): def parse(self, children=None): return cde.RepeatNode(children[0], self.count) - def get_args(self): - args = super().get_args() - args["count"] = self.count - return args - class SkipDataset(Dataset): """ @@ -2549,11 +2434,6 @@ class SkipDataset(Dataset): def parse(self, children=None): return cde.SkipNode(children[0], self.count) - def get_args(self): - args = super().get_args() - args["count"] = self.count - return args - class TakeDataset(Dataset): """ @@ -2571,11 +2451,6 @@ class TakeDataset(Dataset): def parse(self, children=None): return cde.TakeNode(children[0], self.count) - def get_args(self): - args = super().get_args() - args["count"] = self.count - return args - class ZipDataset(Dataset): """ @@ -2590,7 +2465,6 @@ class ZipDataset(Dataset): def __init__(self, datasets): super().__init__(children=datasets) - self.datasets = datasets def parse(self, children=None): return cde.ZipNode(children) @@ -2598,10 +2472,6 @@ class ZipDataset(Dataset): def is_sync(self): return any([c.is_sync() for c in self.children]) - def get_args(self): - args = super().get_args() - return args - class ConcatDataset(Dataset): """ @@ -2669,7 +2539,7 @@ class ConcatDataset(Dataset): ValueError: If num_shards <=0. """ if not isinstance(sampler, samplers.DistributedSampler): - raise TypeError("The parameter %s of concat must be DistributedSampler!" % (sampler)) + raise TypeError("The parameter %s of concat must be DistributedSampler!" 
% sampler) if sampler.is_shuffled(): raise ValueError("The parameter shuffle of DistributedSampler must be False!") @@ -2682,14 +2552,14 @@ class ConcatDataset(Dataset): self.dataset_size = None - self._sampler = samplers.select_sampler(None, sampler, None, None, None) + self._sampler = sampler cumulative_samples_nums = 0 for index, child in enumerate(self.children): if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None: - raise ValueError("The parameter NumSamples of %s is not support to be set!" % (child)) + raise ValueError("The parameter NumSamples of %s is not support to be set!" % child) if isinstance(child, BatchDataset): - raise TypeError("The parameter %s of concat must not be BatchDataset!" % (child)) + raise TypeError("The parameter %s of concat must not be BatchDataset!" % child) # if child is mappable and the length is greater than 0 if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]: @@ -2711,15 +2581,6 @@ class ConcatDataset(Dataset): cumulative_samples_nums += self.children_sizes_[index] cumulative_samples_nums %= sampler.num_shards - def get_args(self): - args = super().get_args() - - if self._sampler is not None: - args["sampler"] = self._sampler - args["children_flag_and_nums"] = self._children_flag_and_nums - args["children_start_end_index"] = self._children_start_end_index_ - return args - class RenameDataset(Dataset): """ @@ -2733,21 +2594,19 @@ class RenameDataset(Dataset): def __init__(self, input_dataset, input_columns, output_columns): super().__init__(children=input_dataset) - if input_columns is not None and not isinstance(input_columns, list): - input_columns = [input_columns] - if output_columns is not None and not isinstance(output_columns, list): - output_columns = [output_columns] - self.input_column_names = replace_none(input_columns, []) - self.output_column_names = replace_none(output_columns, []) + self.input_column_names = to_list(input_columns) + self.output_column_names = to_list(output_columns) def parse(self, children=None): return cde.RenameNode(children[0], self.input_column_names, self.output_column_names) - def get_args(self): - args = super().get_args() - args["input_columns"] = self.input_column_names - args["output_columns"] = self.output_column_names - return args + +def to_list(items): + if items is None: + return [] + if not isinstance(items, list): + return [items] + return items class ProjectDataset(Dataset): @@ -2757,26 +2616,15 @@ class ProjectDataset(Dataset): Args: input_dataset (Dataset): Input Dataset to be Projected. columns (Union[str, list[str]]): List of names of the columns to project. - prefetch_size (int, optional): Prefetch number of records ahead of the - user's request (default=None). 
""" - def __init__(self, input_dataset, columns, prefetch_size=None): + def __init__(self, input_dataset, columns): super().__init__(children=input_dataset) - if columns is not None and not isinstance(columns, list): - columns = [columns] - self.columns = replace_none(columns, []) - self.prefetch_size = prefetch_size + self.columns = to_list(columns) def parse(self, children=None): return cde.ProjectNode(children[0], self.columns) - def get_args(self): - args = super().get_args() - args["columns"] = self.columns - args["prefetch_size"] = self.prefetch_size - return args - class _ToDevice: """ @@ -2855,13 +2703,6 @@ class TransferDataset(Dataset): return cde.TransferNode(children[0], self.queue_name, self.device_type, self._send_epoch_end, total_batch, self._create_data_info_queue) - def get_args(self): - args = super().get_args() - args["send_epoch_end"] = self._send_epoch_end - if hasattr(self.children[0], "__total_batch__"): - args["total_batch"] = self.children[0].__total_batch__ - return args - def create_dict_iterator(self, num_epochs=-1, output_numpy=False): raise RuntimeError("TransferDataset is not iterable.") @@ -2909,22 +2750,6 @@ class TransferDataset(Dataset): if self._to_device is not None: self._to_device.release() - def __deepcopy__(self, memodict): - if id(self) in memodict: - return memodict[id(self)] - cls = self.__class__ - new_op = cls.__new__(cls) - memodict[id(self)] = new_op - new_op.children = copy.deepcopy(self.children, memodict) - new_op.parent = copy.deepcopy(self.parent, memodict) - new_op.num_parallel_workers = self.num_parallel_workers - new_op.queue_name = self.queue_name - new_op.device_type = self.device_type - new_op._send_epoch_end = self._send_epoch_end # pylint: disable=W0212 - new_op._create_data_info_queue = self._create_data_info_queue # pylint: disable=W0212 - - return new_op - class RangeDataset(MappableDataset): """ @@ -2945,13 +2770,6 @@ class RangeDataset(MappableDataset): def parse(self, children=None): raise NotImplementedError("Dataset has to implement parse method.") - def get_args(self): - args = super().get_args() - args["start"] = self.start - args["stop"] = self.stop - args["step"] = self.step - return args - def is_shuffled(self): return False @@ -3050,56 +2868,18 @@ class ImageFolderDataset(MappableDataset): """ @check_imagefolderdataset - def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, extensions=None, class_indexing=None, - decode=False, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, + extensions=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_dir = dataset_dir - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples - self.shuffle_level = shuffle self.extensions = replace_none(extensions, []) - self.class_indexing = class_indexing + self.class_indexing = replace_none(class_indexing, {}) self.decode = replace_none(decode, False) - self.num_shards = num_shards - self.shard_id = shard_id - self.cache = cache def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - 
class_indexing = replace_none(self.class_indexing, {}) - return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, - class_indexing, cc).SetNumWorkers(self.num_parallel_workers) - - def get_args(self): - args = super().get_args() - args["dataset_dir"] = self.dataset_dir - args["num_samples"] = self.num_samples - args["sampler"] = self.sampler - args["shuffle"] = self.shuffle_level - args["extensions"] = self.extensions - args["class_indexing"] = self.class_indexing - args["decode"] = self.decode - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return self.sampler.is_sharded() + return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing) class MnistDataset(MappableDataset): @@ -3186,51 +2966,17 @@ class MnistDataset(MappableDataset): >>> # Note: In mnist_dataset dataset, each dictionary has keys "image" and "label" """ - def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - - return cde.MnistNode(self.dataset_dir, self.usage, self.sampler, cc).SetNumWorkers(self.num_parallel_workers) - @check_mnist_cifar_dataset - def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, + num_shards=None, shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_dir = dataset_dir self.usage = replace_none(usage, "all") - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples - self.shuffle_level = shuffle - self.num_shards = num_shards - self.shard_id = shard_id - self.cache = cache - - def get_args(self): - args = super().get_args() - args["dataset_dir"] = self.dataset_dir - args["usage"] = self.usage - args["num_samples"] = self.num_samples - args["shuffle"] = self.shuffle_level - args["sampler"] = self.sampler - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - return self.sampler.is_sharded() + def parse(self, children=None): + return cde.MnistNode(self.dataset_dir, self.usage, self.sampler) class MindDataset(MappableDataset): @@ -3268,14 +3014,13 @@ class MindDataset(MappableDataset): def parse(self, children=None): return cde.MindDataNode(self.dataset_file, self.columns_list, self.sampler, self.new_padded_sample, - self.num_padded).SetNumWorkers(self.num_parallel_workers) + self.num_padded) @check_minddataset - def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, - 
shuffle=None, num_shards=None, shard_id=None, - sampler=None, padded_sample=None, - num_padded=None, num_samples=None): - super().__init__(num_parallel_workers=num_parallel_workers) + def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None, + shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None): + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) if isinstance(dataset_file, list): self.load_dataset = False else: @@ -3283,20 +3028,17 @@ class MindDataset(MappableDataset): self.dataset_file = dataset_file self.columns_list = replace_none(columns_list, []) self.shuffle_option = shuffle - self.num_shards = num_shards - self.shard_id = shard_id + if shuffle is False: logger.warning("WARN: global shuffle is not used.") if sampler is not None: - if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.SubsetSampler, samplers.PKSampler, - samplers.DistributedSampler, samplers.RandomSampler, - samplers.SequentialSampler)) is False: + if isinstance(sampler, ( + samplers.SubsetRandomSampler, samplers.SubsetSampler, samplers.PKSampler, + samplers.DistributedSampler, + samplers.RandomSampler, samplers.SequentialSampler)) is False: raise ValueError("The sampler is not supported yet.") - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples - self.padded_sample = padded_sample self.num_padded = replace_none(num_padded, 0) @@ -3308,44 +3050,12 @@ class MindDataset(MappableDataset): else: self.new_padded_sample[k] = v - def get_args(self): - args = super().get_args() - padded_sample = None - if self.padded_sample: - padded_sample = {} - for k, v in self.padded_sample.items(): - if isinstance(v, np.ndarray): - padded_sample[k] = v.tobytes() - else: - padded_sample[k] = v - args["dataset_file"] = self.dataset_file - args["load_dataset"] = self.load_dataset - args["columns_list"] = self.columns_list - args["shuffle_option"] = self.shuffle_option - args["num_samples"] = self.num_samples - args["num_padded"] = self.num_padded - args["padded_sample"] = padded_sample - args["sampler"] = self.sampler - return args - - def is_shuffled(self): - if self.shuffle_option is None: - return True - - return self.shuffle_option or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return self.sampler.is_sharded() - def _iter_fn(dataset, num_samples): """ Generator function wrapper for iterable dataset. """ - if num_samples is not None: + if num_samples is not None and num_samples != 0: ds_iter = iter(dataset) for _ in range(num_samples): try: @@ -3364,7 +3074,7 @@ def _generator_fn(generator, num_samples): """ Generator function wrapper for generator function dataset. """ - if num_samples is not None: + if num_samples is not None and num_samples != 0: gen_iter = generator() for _ in range(num_samples): try: @@ -3700,14 +3410,14 @@ class GeneratorDataset(MappableDataset): >>> multi_column_generator_dataset = ds.GeneratorDataset(GeneratorMC, ["col1", "col2"]) >>> >>> # 3) Iterable dataset as iterable input - >>> class MyIterable(): + >>> class MyIterable: ... def __iter__(self): ... 
return # User implementation >>> # Create iterable_generator_dataset with MyIterable object >>> iterable_generator_dataset = ds.GeneratorDataset(MyIterable(), ["col1"]) >>> >>> # 4) Random accessible dataset as random accessible input - >>> class MyRA(): + >>> class MyRA: ... def __getitem__(self, index): ... return # User implementation >>> # Create ra_generator_dataset with MyRA object @@ -3723,22 +3433,19 @@ class GeneratorDataset(MappableDataset): def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, python_multiprocessing=True): - super().__init__(num_parallel_workers=num_parallel_workers) + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) self.source = source - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples - self.num_shards = num_shards + self.python_multiprocessing = python_multiprocessing - self.num_parallel_workers = num_parallel_workers - if column_names is not None and not isinstance(column_names, list): - column_names = [column_names] - self.column_names = replace_none(column_names, []) + self.column_names = to_list(column_names) if column_types is not None: self.column_types = mstypelist_to_detypelist(column_types) else: self.column_types = [] + self.schema = schema if schema is not None: self.schema = schema @@ -3753,24 +3460,9 @@ class GeneratorDataset(MappableDataset): def __deepcopy__(self, memodict): if id(self) in memodict: return memodict[id(self)] - cls = self.__class__ - new_op = cls.__new__(cls) - memodict[id(self)] = new_op - new_op.children = copy.deepcopy(self.children, memodict) - new_op.parent = copy.deepcopy(self.parent, memodict) - new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict) - new_op.schema = copy.deepcopy(self.schema, memodict) - new_op.column_names = copy.deepcopy(self.column_names, memodict) - new_op.column_types = copy.deepcopy(self.column_types, memodict) - new_op.num_samples = copy.deepcopy(self.num_samples, memodict) - new_op.sampler = copy.deepcopy(self.sampler) - new_op.dataset_size = self.dataset_size - new_op.source_len = self.source_len - new_op.saved_output_types = self.saved_output_types - new_op.saved_output_shapes = self.saved_output_shapes + new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__")) + sample_fn = None - if hasattr(self, "__total_batch__"): - new_op.__total_batch__ = self.__total_batch__ if new_op.sampler is not None and hasattr(self.source, "__getitem__"): if new_op.num_parallel_workers > 1: sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) @@ -3803,13 +3495,12 @@ class GeneratorDataset(MappableDataset): def parse(self, children=None): if self.schema is None: - return cde.GeneratorNode(self.source, self.column_names, self.column_types, - self.source_len, self.sampler).SetNumWorkers(self.num_parallel_workers) + return cde.GeneratorNode(self.source, self.column_names, self.column_types, self.source_len, + self.sampler) schema = self.schema if isinstance(schema, Schema): schema = self.schema.cpp_schema - return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler).SetNumWorkers( - self.num_parallel_workers) + return cde.GeneratorNode(self.source, schema, self.source_len, self.sampler) class 
TFRecordDataset(SourceDataset): @@ -3866,110 +3557,26 @@ class TFRecordDataset(SourceDataset): >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema="./schema.json") """ - def parse(self, children=None): - # set c++ parameters - shuffle_flag = 2 - if not isinstance(self._shuffle, Shuffle): - if self._shuffle: - shuffle_flag = 2 - else: - shuffle_flag = 0 - else: - if self._shuffle == Shuffle.GLOBAL: - shuffle_flag = 2 - elif self._shuffle == Shuffle.FILES: - shuffle_flag = 1 - - schema = self.schema - if isinstance(schema, Schema): - schema = self.schema.cpp_schema - if self.cache: - cc = self.cache.cache_client - else: - cc = None - - num_shards = replace_none(self.num_shards, 1) - shard_id = replace_none(self.shard_id, 0) - num_samples = replace_none(self.num_samples, 0) - - return cde.TFRecordNode(self.dataset_files, schema, self.columns_list, num_samples, - shuffle_flag, - num_shards, shard_id, - self.shard_equal_rows, cc).SetNumWorkers(self.num_parallel_workers) - @check_tfrecorddataset def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, + num_shards=num_shards, shard_id=shard_id, cache=cache) # todo push down to c++ self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() - if not isinstance(self.dataset_files, list): - self.dataset_files = [self.dataset_files] - self.num_shards = num_shards - self.shard_id = shard_id + self.schema = schema - self._shuffle = shuffle self.columns_list = replace_none(columns_list, []) - self.num_samples = num_samples - self.cache = cache - if self.num_samples is None: - schema_obj = self.schema - if not isinstance(schema_obj, Schema): - schema_obj = Schema(schema_obj) - schema_num_samples = schema_obj.cpp_schema.get_num_rows() - if schema_num_samples != 0: - self.num_samples = schema_num_samples - - if not isinstance(shuffle, (bool, Shuffle)): - raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like" - " 'Shuffle.GLOBAL' or 'Shuffle.FILES'.") - if not isinstance(shuffle, Shuffle): - if shuffle: - self.shuffle_level = Shuffle.GLOBAL - self.shuffle_files = True - else: - self.shuffle_level = None - self.shuffle_files = False - else: - self.shuffle_level = shuffle - self.shuffle_files = True - self.shard_equal_rows = replace_none(shard_equal_rows, False) - def get_args(self): - args = super().get_args() - args["dataset_files"] = self.dataset_files - if self.schema is not None: - if isinstance(self.schema, Schema): - self.schema.datasetType = 'TF' - if self.num_samples is not None: - self.schema.num_rows = self.num_samples - args["schema_json_string"] = self.schema.to_json() - else: - args["schema_file_path"] = self.schema - args["schema"] = self.schema - args["columns_list"] = self.columns_list - args["num_samples"] = self.num_samples - if self.shuffle_files is not None: - args["shuffle_files"] = self.shuffle_files - args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) - args["shuffle"] = self._shuffle - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["shard_equal_rows"] = self.shard_equal_rows - args["cache"] = self.cache.cache_client if self.cache is not None else None - args["sampler"] = self.sampler - return args - - def 
is_shuffled(self): - return self.shuffle_files + if self.schema is not None and (self.num_samples is None or self.num_samples == 0): + self.num_samples = Schema.get_num_rows(self.schema) - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return False + def parse(self, children=None): + schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema + return cde.TFRecordNode(self.dataset_files, schema, self.columns_list, self.num_samples, self.shuffle_flag, + self.num_shards, self.shard_id, self.shard_equal_rows) class ManifestDataset(MappableDataset): @@ -4050,49 +3657,19 @@ class ManifestDataset(MappableDataset): """ - def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - class_indexing = replace_none(self.class_indexing, {}) - return cde.ManifestNode(self.dataset_file, self.usage, self.sampler, class_indexing, - self.decode, cc).SetNumWorkers(self.num_parallel_workers) - @check_manifestdataset - def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, - cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None, shuffle=None, + sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_file = dataset_file - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - - if class_indexing is not None and not isinstance(class_indexing, dict): - raise RuntimeError("class_indexing must be a dictionary.") - - self.num_samples = num_samples - self.class_indexing = class_indexing self.decode = replace_none(decode, False) self.usage = replace_none(usage, "train") - self.shuffle_level = shuffle - self.num_shards = num_shards - self.shard_id = shard_id - self.cache = cache + self.class_indexing = replace_none(class_indexing, {}) - def get_args(self): - args = super().get_args() - args["dataset_file"] = self.dataset_file - args["usage"] = self.usage - args["num_samples"] = self.num_samples - args["shuffle"] = self.shuffle_level - args["sampler"] = self.sampler - args["class_indexing"] = self.class_indexing - args["decode"] = self.decode - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args + def parse(self, children=None): + return cde.ManifestNode(self.dataset_file, self.usage, self.sampler, self.class_indexing, self.decode) def get_class_indexing(self): """ @@ -4101,7 +3678,7 @@ class ManifestDataset(MappableDataset): Returns: dict, a str-to-int mapping from label name to index. 
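Editor's note: with the removed boilerplate gone, TFRecordDataset now derives a default num_samples from the schema's recorded row count via the new Schema.get_num_rows. A hedged usage sketch (file names are hypothetical):

import mindspore.dataset as ds

# Hypothetical paths; the schema JSON is expected to record a row count.
dataset = ds.TFRecordDataset(dataset_files="./data-0.tfrecord", schema="./schema.json")
# num_samples was left as None, so the constructor falls back to the schema's
# row count (Schema.get_num_rows); an explicit non-zero num_samples still wins.
print(dataset.num_samples)
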
""" - if self.class_indexing is None: + if self.class_indexing is None or not self.class_indexing: if self._class_indexing is None: runtime_getter = self._init_tree_getters() self._class_indexing = runtime_getter[0].GetClassIndexing() @@ -4110,18 +3687,6 @@ class ManifestDataset(MappableDataset): self.class_indexing[pair[0]] = pair[1][0] return self.class_indexing - def is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return self.sampler.is_sharded() - class Cifar10Dataset(MappableDataset): """ @@ -4213,50 +3778,17 @@ class Cifar10Dataset(MappableDataset): >>> # In CIFAR10 dataset, each dictionary has keys "image" and "label" """ - def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - return cde.Cifar10Node(self.dataset_dir, self.usage, self.sampler, cc).SetNumWorkers(self.num_parallel_workers) - @check_mnist_cifar_dataset - def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, + num_shards=None, shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_dir = dataset_dir self.usage = replace_none(usage, "all") - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples - self.num_shards = num_shards - self.shard_id = shard_id - self.shuffle_level = shuffle - self.cache = cache - - def get_args(self): - args = super().get_args() - args["dataset_dir"] = self.dataset_dir - args["usage"] = self.usage - args["num_samples"] = self.num_samples - args["sampler"] = self.sampler - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["shuffle"] = self.shuffle_level - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - return self.sampler.is_sharded() + def parse(self, children=None): + return cde.Cifar10Node(self.dataset_dir, self.usage, self.sampler) class Cifar100Dataset(MappableDataset): @@ -4348,50 +3880,17 @@ class Cifar100Dataset(MappableDataset): >>> # In CIFAR100 dataset, each dictionary has 3 keys: "image", "fine_label" and "coarse_label" """ - def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - return cde.Cifar100Node(self.dataset_dir, self.usage, self.sampler, cc).SetNumWorkers(self.num_parallel_workers) - @check_mnist_cifar_dataset - def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, - shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, + num_shards=None, shard_id=None, cache=None): + 
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_dir = dataset_dir self.usage = replace_none(usage, "all") - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples - self.num_shards = num_shards - self.shard_id = shard_id - self.shuffle_level = shuffle - self.cache = cache - - def get_args(self): - args = super().get_args() - args["dataset_dir"] = self.dataset_dir - args["usage"] = self.usage - args["num_samples"] = self.num_samples - args["sampler"] = self.sampler - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["shuffle"] = self.shuffle_level - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - if self.shuffle_level is None: - return True - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return self.sampler.is_sharded() + def parse(self, children=None): + return cde.Cifar100Node(self.dataset_dir, self.usage, self.sampler) class RandomDataset(SourceDataset): @@ -4417,71 +3916,20 @@ class RandomDataset(SourceDataset): argument can only be specified when num_shards is also specified. """ - def parse(self, children=None): - schema = self.schema - if isinstance(schema, Schema): - schema = self.schema.cpp_schema - if self.cache: - cc = self.cache.cache_client - else: - cc = None - return cde.RandomNode(self.total_rows, schema, self.columns_list, cc).SetNumWorkers( - self.num_parallel_workers) - @check_random_dataset def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, cache=None, shuffle=None, num_shards=None, shard_id=None): - super().__init__(num_parallel_workers=num_parallel_workers) + super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, + num_shards=num_shards, shard_id=shard_id, cache=cache) + self.total_rows = total_rows + if schema is not None: + self.total_rows = replace_none(total_rows, Schema.get_num_rows(schema)) self.schema = schema self.columns_list = replace_none(columns_list, []) - self.num_samples = num_samples - self.total_rows = total_rows - self.cache = cache - if self.total_rows is None: - schema_obj = self.schema - if not isinstance(schema_obj, Schema): - schema_obj = Schema(schema_obj) - schema_total_rows = schema_obj.cpp_schema.get_num_rows() - if schema_total_rows != 0: - self.total_rows = schema_total_rows - self.total_rows = replace_none(self.total_rows, 0) - - self.num_shards = replace_none(num_shards, 1) - self.shard_id = replace_none(shard_id, 0) - self.shuffle_level = replace_none(shuffle, False) - - self.num_samples = num_samples - - def get_args(self): - args = super().get_args() - if self.schema is not None: - if isinstance(self.schema, Schema): - self.schema.datasetType = 'Random' - if self.total_rows is not None: - self.schema.num_rows = self.total_rows - args["schema_json_string"] = self.schema.to_json() - else: - args["schema_file_path"] = self.schema - args["schema"] = self.schema - args["columns_list"] = self.columns_list - args["num_samples"] = self.num_samples - args["total_rows"] = self.total_rows - args["cache"] = self.cache.cache_client if self.cache is not None else None - args["sampler"] = self.sampler - return args - - def 
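Editor's note: the repeated constructor/get_args/is_shuffled/is_sharded code is now concentrated in the base classes, so a mappable dataset only stores its own fields and builds its IR node in parse(). A hedged sketch of the post-cleanup pattern (MyImageDataset and cde.MyImageNode are hypothetical stand-ins):

class MyImageDataset(MappableDataset):
    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None,
                 sampler=None, num_shards=None, shard_id=None, cache=None):
        # Sampling, sharding, shuffling and cache handling are delegated to
        # MappableDataset; only dataset-specific state stays here.
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.dataset_dir = dataset_dir

    def parse(self, children=None):
        # Worker count and cache are presumably attached by whoever calls
        # parse(), now that SetNumWorkers is no longer chained here.
        return cde.MyImageNode(self.dataset_dir, self.sampler)
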
is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return self.sampler.is_sharded() + def parse(self, children=None): + schema = self.schema.cpp_schema if isinstance(self.schema, Schema) else self.schema + return cde.RandomNode(self.total_rows, schema, self.columns_list) class Schema: @@ -4586,6 +4034,13 @@ class Schema: def __str__(self): return self.to_json() + @staticmethod + def get_num_rows(schema): + schema_obj = schema + if not isinstance(schema_obj, Schema): + schema_obj = Schema(schema_obj) + return schema_obj.cpp_schema.get_num_rows() + class VOCDataset(MappableDataset): """ @@ -4702,46 +4157,20 @@ class VOCDataset(MappableDataset): >>> # In VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation" """ - def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - class_indexing = replace_none(self.class_indexing, {}) - return cde.VOCNode(self.dataset_dir, self.task, self.usage, class_indexing, self.decode, - self.sampler, cc).SetNumWorkers(self.num_parallel_workers) - @check_vocdataset def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_dir = dataset_dir self.task = replace_none(task, "Segmentation") self.usage = replace_none(usage, "train") - self.class_indexing = class_indexing - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples + self.class_indexing = replace_none(class_indexing, {}) self.decode = replace_none(decode, False) - self.shuffle_level = shuffle - self.num_shards = num_shards - self.shard_id = shard_id - self.cache = cache - def get_args(self): - args = super().get_args() - args["dataset_dir"] = self.dataset_dir - args["task"] = self.task - args["usage"] = self.usage - args["class_indexing"] = self.class_indexing - args["num_samples"] = self.num_samples - args["sampler"] = self.sampler - args["decode"] = self.decode - args["shuffle"] = self.shuffle_level - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args + def parse(self, children=None): + return cde.VOCNode(self.dataset_dir, self.task, self.usage, self.class_indexing, self.decode, self.sampler) def get_class_indexing(self): """ @@ -4752,7 +4181,7 @@ class VOCDataset(MappableDataset): """ if self.task != "Detection": raise NotImplementedError("Only 'Detection' support get_class_indexing.") - if self.class_indexing is None: + if self.class_indexing is None or not self.class_indexing: if self._class_indexing is None: runtime_getter = self._init_tree_getters() self._class_indexing = runtime_getter[0].GetClassIndexing() @@ -4761,18 +4190,6 @@ class VOCDataset(MappableDataset): self.class_indexing[pair[0]] = pair[1][0] return self.class_indexing - def is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def 
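Editor's note: the new Schema.get_num_rows static method normalizes its input, so callers can pass either a Schema object or anything the Schema constructor accepts. A hedged sketch (file name hypothetical; both calls agree when the JSON records a row count):

from mindspore.dataset import Schema

rows_from_path = Schema.get_num_rows("./schema.json")
rows_from_obj = Schema.get_num_rows(Schema("./schema.json"))
assert rows_from_path == rows_from_obj  # 0 when the schema defines no rows
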
is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return self.sampler.is_sharded() - class CocoDataset(MappableDataset): """ @@ -4892,44 +4309,18 @@ class CocoDataset(MappableDataset): >>> # In COCO dataset, each dictionary has keys "image" and "annotation" """ - def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, - self.sampler, cc).SetNumWorkers(self.num_parallel_workers) - @check_cocodataset def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_dir = dataset_dir self.annotation_file = annotation_file self.task = replace_none(task, "Detection") - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_samples = num_samples self.decode = replace_none(decode, False) - self.shuffle_level = shuffle - self.num_shards = num_shards - self.shard_id = shard_id - self.cache = cache - self.dataset_dir = dataset_dir - self.annotation_file = annotation_file - def get_args(self): - args = super().get_args() - args["dataset_dir"] = self.dataset_dir - args["annotation_file"] = self.annotation_file - args["task"] = self.task - args["num_samples"] = self.num_samples - args["sampler"] = self.sampler - args["decode"] = self.decode - args["shuffle"] = self.shuffle_level - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args + def parse(self, children=None): + return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, self.sampler) def get_class_indexing(self): """ @@ -4945,18 +4336,6 @@ class CocoDataset(MappableDataset): self._class_indexing = dict(runtime_getter[0].GetClassIndexing()) return self._class_indexing - def is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return self.sampler.is_sharded() - class CelebADataset(MappableDataset): """ @@ -5018,60 +4397,23 @@ class CelebADataset(MappableDataset): >>> dataset = ds.CelebADataset(dataset_dir=celeba_dataset_dir, usage='train') """ - def parse(self, children=None): - if self.usage != "all": - dir = os.path.realpath(self.dataset_dir) - partition_file = os.path.join(dir, "list_eval_partition.txt") - if os.path.exists(partition_file) is False: - raise RuntimeError("Partition file can not be found when usage is not 'all'.") - if self.cache: - cc = self.cache.cache_client - else: - cc = None - return cde.CelebANode(self.dataset_dir, self.usage, self.sampler, self.decode, self.extensions, cc). 
\ - SetNumWorkers(self.num_parallel_workers) - @check_celebadataset def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, + shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_dir = dataset_dir - self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) - self.num_parallel_workers = num_parallel_workers self.decode = replace_none(decode, False) self.extensions = replace_none(extensions, []) - self.num_samples = num_samples self.usage = replace_none(usage, "all") - self.num_shards = num_shards - self.shard_id = shard_id - self.shuffle_level = shuffle - self.cache = cache - - def get_args(self): - args = super().get_args() - args["dataset_dir"] = self.dataset_dir - args["sampler"] = self.sampler - args["shuffle"] = self.shuffle_level - args["decode"] = self.decode - args["extensions"] = self.extensions - args["num_samples"] = self.num_samples - args["usage"] = self.usage - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - if self.shuffle_level is None: - return True - - return self.shuffle_level or self.sampler.is_shuffled() - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - return self.sampler.is_sharded() + def parse(self, children=None): + if self.usage != "all": + dir = os.path.realpath(self.dataset_dir) + partition_file = os.path.join(dir, "list_eval_partition.txt") + if os.path.exists(partition_file) is False: + raise RuntimeError("Partition file can not be found when usage is not 'all'.") + return cde.CelebANode(self.dataset_dir, self.usage, self.sampler, self.decode, self.extensions) class CLUEDataset(SourceDataset): @@ -5130,199 +4472,93 @@ class CLUEDataset(SourceDataset): >>> dataset = ds.CLUEDataset(dataset_files=clue_dataset_dir, task='AFQMC', usage='train') """ - def parse(self, children=None): - # C default values - shuffle_flag = 2 - if not isinstance(self._shuffle, Shuffle): - if self._shuffle: - shuffle_flag = 2 - else: - shuffle_flag = 0 - else: - if self._shuffle == Shuffle.GLOBAL: - shuffle_flag = 2 - elif self._shuffle == Shuffle.FILES: - shuffle_flag = 1 - if self.cache: - cc = self.cache.cache_client - else: - cc = None - return cde.CLUENode(self.dataset_files, self.task, self.usage, self.num_samples, shuffle_flag, - self.num_shards, - self.shard_id, cc).SetNumWorkers(self.num_parallel_workers) - @check_cluedataset - def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, - num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None, num_parallel_workers=None, + shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, + num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_files = self._find_files(dataset_files) - if not isinstance(self.dataset_files, list): - self.dataset_files = 
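Editor's note: CelebADataset keeps the partition-file requirement inside parse(): any usage other than "all" needs list_eval_partition.txt in the dataset directory, and the error surfaces when the pipeline is built rather than at Python construction time. A hedged usage sketch (directory is hypothetical):

import mindspore.dataset as ds

celeba = ds.CelebADataset("./celeba", usage="train")  # expects ./celeba/list_eval_partition.txt
# The check runs inside parse(), so a missing partition file raises a
# RuntimeError once the pipeline is actually built and iterated.
for _ in celeba.create_dict_iterator(num_epochs=1):
    break
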
[self.dataset_files] - self.num_samples = replace_none(num_samples, 0) + self.task_dict = { 'AFQMC': { 'train': { - 'sentence1': 'sentence1', - 'sentence2': 'sentence2', - 'label': 'label' + 'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label' }, 'test': { - 'id': 'id', - 'sentence1': 'sentence1', - 'sentence2': 'sentence2' + 'id': 'id', 'sentence1': 'sentence1', 'sentence2': 'sentence2' }, 'eval': { - 'sentence1': 'sentence1', - 'sentence2': 'sentence2', - 'label': 'label' + 'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label' } }, 'CMNLI': { 'train': { - 'sentence1': 'sentence1', - 'sentence2': 'sentence2', - 'label': 'label' + 'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label' }, 'test': { - 'id': 'id', - 'sentence1': 'sentence1', - 'sentence2': 'sentence2' + 'id': 'id', 'sentence1': 'sentence1', 'sentence2': 'sentence2' }, 'eval': { - 'sentence1': 'sentence1', - 'sentence2': 'sentence2', - 'label': 'label' + 'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label' } }, 'CSL': { 'train': { - 'id': 'id', - 'abst': 'abst', - 'keyword': 'keyword', - 'label': 'label' + 'id': 'id', 'abst': 'abst', 'keyword': 'keyword', 'label': 'label' }, 'test': { - 'id': 'id', - 'abst': 'abst', - 'keyword': 'keyword' + 'id': 'id', 'abst': 'abst', 'keyword': 'keyword' }, 'eval': { - 'id': 'id', - 'abst': 'abst', - 'keyword': 'keyword', - 'label': 'label' + 'id': 'id', 'abst': 'abst', 'keyword': 'keyword', 'label': 'label' } }, 'IFLYTEK': { 'train': { - 'label': 'label', - 'label_des': 'label_des', - 'sentence': 'sentence' + 'label': 'label', 'label_des': 'label_des', 'sentence': 'sentence' }, 'test': { - 'id': 'id', - 'sentence': 'sentence', + 'id': 'id', 'sentence': 'sentence', }, 'eval': { - 'label': 'label', - 'label_des': 'label_des', - 'sentence': 'sentence' + 'label': 'label', 'label_des': 'label_des', 'sentence': 'sentence' } }, 'TNEWS': { 'train': { - 'label': 'label', - 'label_desc': 'label_desc', - 'sentence': 'sentence', - 'keywords': 'keywords' + 'label': 'label', 'label_desc': 'label_desc', 'sentence': 'sentence', 'keywords': 'keywords' }, 'test': { - 'id': 'id', - 'sentence': 'sentence', - 'keywords': 'keywords' + 'id': 'id', 'sentence': 'sentence', 'keywords': 'keywords' }, 'eval': { - 'label': 'label', - 'label_desc': 'label_desc', - 'sentence': 'sentence', - 'keywords': 'keywords' + 'label': 'label', 'label_desc': 'label_desc', 'sentence': 'sentence', 'keywords': 'keywords' } }, 'WSC': { 'train': { - 'span1_index': 'target/span1_index', - 'span2_index': 'target/span2_index', - 'span1_text': 'target/span1_text', - 'span2_text': 'target/span2_text', - 'idx': 'idx', - 'label': 'label', - 'text': 'text' + 'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index', + 'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text', 'idx': 'idx', + 'label': 'label', 'text': 'text' }, 'test': { - 'span1_index': 'target/span1_index', - 'span2_index': 'target/span2_index', - 'span1_text': 'target/span1_text', - 'span2_text': 'target/span2_text', - 'idx': 'idx', - 'text': 'text' + 'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index', + 'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text', 'idx': 'idx', 'text': 'text' }, 'eval': { - 'span1_index': 'target/span1_index', - 'span2_index': 'target/span2_index', - 'span1_text': 'target/span1_text', - 'span2_text': 'target/span2_text', - 'idx': 'idx', - 'label': 'label', - 'text': 'text' + 'span1_index': 'target/span1_index', 'span2_index': 
'target/span2_index', + 'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text', 'idx': 'idx', + 'label': 'label', 'text': 'text' } } } self.usage = replace_none(usage, 'train') self.cols_to_keyword = self.task_dict[task][self.usage] self.task = replace_none(task, 'AFQMC') - self._shuffle = shuffle - if not isinstance(shuffle, (bool, Shuffle)): - raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like" - " 'Shuffle.GLOBAL' or 'Shuffle.FILES'.") - # To be removed later - if not isinstance(shuffle, Shuffle): - if shuffle: - self.shuffle_level = Shuffle.GLOBAL - self.shuffle_files = True - else: - self.shuffle_level = None - self.shuffle_files = False - else: - self.shuffle_level = shuffle - self.shuffle_files = True - - self.num_shards = replace_none(num_shards, 1) - self.shard_id = replace_none(shard_id, 0) - self.cache = cache - def get_args(self): - args = super().get_args() - args["dataset_files"] = self.dataset_files - args["num_samples"] = self.num_samples - if self.shuffle_files is not None: - args["shuffle_files"] = self.shuffle_files - args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) - args["shuffle"] = self.shuffle_level - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["cols_to_keyword"] = self.cols_to_keyword - args["sampler"] = self.sampler - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - return self.shuffle_files - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return False + def parse(self, children=None): + return cde.CLUENode(self.dataset_files, self.task, self.usage, self.num_samples, self.shuffle_flag, + self.num_shards, self.shard_id) class CSVDataset(SourceDataset): @@ -5364,76 +4600,20 @@ class CSVDataset(SourceDataset): >>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4']) """ - def parse(self, children=None): - if self.cache: - cc = self.cache.cache_client - else: - cc = None - return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names, - self.num_samples, - self.shuffle_flag, self.num_shards, - self.shard_id, cc).SetNumWorkers(self.num_parallel_workers) - @check_csvdataset def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) + super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, + num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() self.field_delim = replace_none(field_delim, ',') self.column_defaults = replace_none(column_defaults, []) self.column_names = replace_none(column_names, []) - self.num_samples = num_samples - - if not isinstance(shuffle, (bool, Shuffle)): - raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like" - " 'Shuffle.GLOBAL' or 'Shuffle.FILES'.") - self.shuffle_flag = 2 - if not isinstance(shuffle, Shuffle): - if shuffle: - self.shuffle_flag = 2 - self.shuffle_files = True - else: - self.shuffle_flag = 0 - self.shuffle_files = False - else: - if shuffle == Shuffle.GLOBAL: - self.shuffle_flag = 2 - elif shuffle == Shuffle.FILES: - self.shuffle_flag = 1 - self.shuffle_files = True - - self.cache = cache - - 
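Editor's note: the reflowed task_dict maps each (task, usage) pair to an output-column-to-JSON-keyword table; the selected mapping is kept on the Python object as cols_to_keyword (presumably consumed when columns are renamed downstream, since the new parse() no longer passes it to the C++ node). A hedged sketch (file name hypothetical):

import mindspore.dataset as ds

clue = ds.CLUEDataset("./wsc_train.json", task="WSC", usage="train")
# For WSC/train the nested JSON fields are addressed as path-like keywords.
print(clue.cols_to_keyword["span1_index"])  # -> "target/span1_index"
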
self.num_shards = replace_none(num_shards, 1) - self.shard_id = replace_none(shard_id, 0) - self.num_samples = replace_none(num_samples, 0) - - def get_args(self): - args = super().get_args() - args["dataset_files"] = self.dataset_files - args['field_delim'] = self.field_delim - args['column_defaults'] = self.column_defaults - args['column_names'] = self.column_names - args["num_samples"] = self.num_samples - if self.shuffle_files is not None: - args["shuffle_files"] = self.shuffle_files - args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) - args["shuffle"] = self.shuffle_level - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["sampler"] = self.sampler - args["cache"] = self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - return self.shuffle_files - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return False + def parse(self, children=None): + return cde.CSVNode(self.dataset_files, self.field_delim, self.column_defaults, self.column_names, + self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id) class TextFileDataset(SourceDataset): @@ -5469,76 +4649,16 @@ class TextFileDataset(SourceDataset): >>> dataset = ds.TextFileDataset(dataset_files=text_file_dataset_dir) """ - def parse(self, children=None): - shuffle_flag = 2 - if not isinstance(self._shuffle, Shuffle): - if self._shuffle: - shuffle_flag = 2 - else: - shuffle_flag = 0 - else: - if self._shuffle == Shuffle.GLOBAL: - shuffle_flag = 2 - elif self._shuffle == Shuffle.FILES: - shuffle_flag = 1 - if self.cache: - cc = self.cache.cache_client - else: - cc = None - - return cde.TextFileNode(self.dataset_files, self.num_samples, shuffle_flag, self.num_shards, - self.shard_id, cc).SetNumWorkers(self.num_parallel_workers) - @check_textfiledataset - def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, - shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None): - super().__init__(num_parallel_workers=num_parallel_workers) - self._shuffle = shuffle + def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, + num_shards=None, shard_id=None, cache=None): + super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, + num_shards=num_shards, shard_id=shard_id, cache=cache) self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() - self.num_samples = replace_none(num_samples, 0) - if not isinstance(shuffle, (bool, Shuffle)): - raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like" - " 'Shuffle.GLOBAL' or 'Shuffle.FILES'.") - if not isinstance(shuffle, Shuffle): - if shuffle: - self.shuffle_level = Shuffle.GLOBAL - self.shuffle_files = True - else: - self.shuffle_level = None - self.shuffle_files = False - else: - self.shuffle_level = shuffle - self.shuffle_files = True - - self.num_shards = replace_none(num_shards, 1) - self.shard_id = replace_none(shard_id, 0) - - self.cache = cache - - def get_args(self): - args = super().get_args() - args["dataset_files"] = self.dataset_files - args["num_samples"] = self.num_samples - if self.shuffle_files is not None: - args["shuffle_files"] = self.shuffle_files - args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL) - args["shuffle"] = self.shuffle_level - args["num_shards"] = self.num_shards - args["shard_id"] = self.shard_id - args["sampler"] = self.sampler - args["cache"] = 
self.cache.cache_client if self.cache is not None else None - return args - - def is_shuffled(self): - return self.shuffle_files - - def is_sharded(self): - if self.num_shards is not None: - return self.num_shards > 1 - - return False + def parse(self, children=None): + return cde.TextFileNode(self.dataset_files, self.num_samples, self.shuffle_flag, self.num_shards, self.shard_id) class _NumpySlicesDataset: @@ -5679,8 +4799,8 @@ class NumpySlicesDataset(GeneratorDataset): """ @check_numpyslicesdataset - def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, - sampler=None, num_shards=None, shard_id=None): + def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, + num_shards=None, shard_id=None): dataset = _NumpySlicesDataset(data, column_names) super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, @@ -5728,8 +4848,6 @@ class PaddedDataset(GeneratorDataset): @check_paddeddataset def __init__(self, padded_samples): dataset = _PaddedDataset(padded_samples) - super().__init__(dataset, column_names=dataset.column_names, - num_shards=None, - shard_id=None, shuffle=False) + super().__init__(dataset, column_names=dataset.column_names, num_shards=None, shard_id=None, shuffle=False) self._dataset_size = len(dataset.padded_samples) self.padded_samples = padded_samples diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index dec448b92e..ec5f8780a9 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -83,7 +83,7 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): shuffle = True return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) # If shuffle is not specified, sharding disabled, use random sampler - if num_samples is not None: + if num_samples is not None and num_samples != 0: return RandomSampler(replacement=True, num_samples=num_samples) return RandomSampler(num_samples=num_samples) if shuffle is True: diff --git a/mindspore/dataset/transforms/py_transforms.py b/mindspore/dataset/transforms/py_transforms.py index 292e3c1613..44cad58ded 100644 --- a/mindspore/dataset/transforms/py_transforms.py +++ b/mindspore/dataset/transforms/py_transforms.py @@ -123,6 +123,35 @@ class Compose: """ return util.compose(self.transforms, *args) + @staticmethod + def reduce(operations): + """ + Wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations + Args: + operations (list): list of tensor operations + + Returns: + list, the reduced list of operations + """ + # + if len(operations) == 1: + return operations + + new_ops, start_ind, end_ind = [], 0, 0 + for i, op in enumerate(operations): + if str(op).find("c_transform") >= 0: + # reset counts + if start_ind != end_ind: + new_ops.append(Compose(operations[start_ind:end_ind])) + new_ops.append(op) + start_ind, end_ind = i + 1, i + 1 + else: + end_ind += 1 + # do additional check in case the last operation is a Python operation + if start_ind != end_ind: + new_ops.append(Compose(operations[start_ind:end_ind])) + return new_ops + class RandomApply: """
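Editor's note: the new Compose.reduce groups maximal runs of adjacent Python operations into a single Compose so they can be interleaved with C++ operations; the check is textual (the operation's string form containing "c_transform"). A hedged illustration with stand-in operations:

from mindspore.dataset.transforms.py_transforms import Compose

class FakeCOp:
    # Stand-in whose repr mimics a C++ transform, which is what reduce() keys on.
    def __repr__(self):
        return "c_transform: FakeCOp"

def py_op_a(x):
    return x

def py_op_b(x):
    return x

ops = [py_op_a, py_op_b, FakeCOp(), py_op_a]
reduced = Compose.reduce(ops)
# -> [Compose([py_op_a, py_op_b]), FakeCOp(), Compose([py_op_a])]
print(reduced)
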