!11971 Python Datasets cleanup

From: @hfarahat
Reviewed-by: 
Signed-off-by:
pull/11971/MERGE
mindspore-ci-bot · 4 years ago · committed by Gitee
commit 8e592d9673

@@ -82,10 +82,15 @@ namespace dataset {
 PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
                   (void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
-                    .def("SetNumWorkers",
+                    .def("set_num_workers",
                          [](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) {
                            return num_workers ? self->SetNumWorkers(*num_workers) : self;
                          })
+                    .def("set_cache_client",
+                         [](std::shared_ptr<DatasetNode> self) {
+                           std::shared_ptr<DatasetCache> dc = nullptr;
+                           return self->SetDatasetCache(dc);
+                         })
                     .def(
                       "Zip",
                       [](std::shared_ptr<DatasetNode> self, py::list datasets) {
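Note: after this hunk the binding is exposed in snake_case. A hedged sketch of what calls into these bindings look like from the Python layer; `node` stands for any cde.Dataset (DatasetNode) built by the node constructors below, and the call sites shown are illustrative, not this PR's actual Python-layer code:

    import mindspore._c_dataengine as cde  # the pybind11 module these hunks register into

    # `node` is any cde.Dataset (DatasetNode); construction elided.
    node = node.set_num_workers(8)     # renamed from "SetNumWorkers"
    node = node.set_num_workers(None)  # std::optional path: returns the node unchanged
    node = node.set_cache_client()     # currently just resets the node's cache to nullptr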
@@ -109,10 +114,9 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
                   (void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode",
                                                                                          "to create a CelebANode")
                     .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler, bool decode,
-                                     py::list extensions, std::shared_ptr<CacheClient> cc) {
-                      auto celebA =
-                        std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
-                                                     toStringSet(extensions), toDatasetCache(std::move(cc)));
+                                     py::list extensions) {
+                      auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode,
+                                                                 toStringSet(extensions), nullptr);
                       THROW_IF_ERROR(celebA->ValidateParams());
                       return celebA;
                     }));
@@ -121,10 +125,8 @@ PYBIND_REGISTER(CelebANode, 2, ([](const py::module *m) {
 PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
                   (void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node",
                                                                                            "to create a Cifar10Node")
-                    .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
-                                     std::shared_ptr<CacheClient> cc) {
-                      auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler),
-                                                                   toDatasetCache(std::move(cc)));
+                    .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
+                      auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
                       THROW_IF_ERROR(cifar10->ValidateParams());
                       return cifar10;
                     }));
@@ -133,23 +135,22 @@ PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) {
 PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) {
                   (void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node",
                                                                                              "to create a Cifar100Node")
-                    .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
-                                     std::shared_ptr<CacheClient> cc) {
-                      auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler),
-                                                                     toDatasetCache(std::move(cc)));
+                    .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
+                      auto cifar100 =
+                        std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
                       THROW_IF_ERROR(cifar100->ValidateParams());
                       return cifar100;
                     }));
                 }));
-PYBIND_REGISTER(
-  CLUENode, 2, ([](const py::module *m) {
-    (void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode")
-      .def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle,
-                       int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) {
-        std::shared_ptr<CLUENode> clue_node =
-          std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle),
-                                              num_shards, shard_id, toDatasetCache(std::move(cc)));
-        THROW_IF_ERROR(clue_node->ValidateParams());
-        return clue_node;
-      }));
+PYBIND_REGISTER(CLUENode, 2, ([](const py::module *m) {
+                  (void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode",
+                                                                                     "to create a CLUENode")
+                    .def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples,
+                                     int32_t shuffle, int32_t num_shards, int32_t shard_id) {
+                      std::shared_ptr<CLUENode> clue_node =
+                        std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples,
+                                                            toShuffleMode(shuffle), num_shards, shard_id, nullptr);
+                      THROW_IF_ERROR(clue_node->ValidateParams());
+                      return clue_node;
+                    }));
@@ -159,10 +160,9 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
                   (void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
                                                                                      "to create a CocoNode")
                     .def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
-                                     bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) {
-                      std::shared_ptr<CocoNode> coco =
-                        std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, toSamplerObj(sampler),
-                                                   toDatasetCache(std::move(cc)));
+                                     bool decode, py::handle sampler) {
+                      std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
+                        dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr);
                       THROW_IF_ERROR(coco->ValidateParams());
                       return coco;
                     }));
@@ -172,10 +172,10 @@ PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) {
                   (void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode")
                     .def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults,
                                      std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle,
-                                     int32_t num_shards, int32_t shard_id, std::shared_ptr<CacheClient> cc) {
-                      auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults),
-                                                           column_names, num_samples, toShuffleMode(shuffle),
-                                                           num_shards, shard_id, toDatasetCache(std::move(cc)));
+                                     int32_t num_shards, int32_t shard_id) {
+                      auto csv =
+                        std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), column_names,
+                                                  num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr);
                       THROW_IF_ERROR(csv->ValidateParams());
                       return csv;
                     }));
@@ -205,12 +205,12 @@ PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
                   (void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
                     *m, "ImageFolderNode", "to create an ImageFolderNode")
                     .def(py::init([](std::string dataset_dir, bool decode, py::handle sampler, py::list extensions,
-                                     py::dict class_indexing, std::shared_ptr<CacheClient> cc) {
+                                     py::dict class_indexing) {
                       // Don't update recursive to true
                       bool recursive = false;  // Will be removed in future PR
-                      auto imagefolder = std::make_shared<ImageFolderNode>(
-                        dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions),
-                        toStringMap(class_indexing), toDatasetCache(std::move(cc)));
+                      auto imagefolder = std::make_shared<ImageFolderNode>(dataset_dir, decode, toSamplerObj(sampler),
+                                                                           recursive, toStringSet(extensions),
+                                                                           toStringMap(class_indexing), nullptr);
                       THROW_IF_ERROR(imagefolder->ValidateParams());
                       return imagefolder;
                     }));
@@ -220,10 +220,9 @@ PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) {
                   (void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode",
                                                                                              "to create a ManifestNode")
                     .def(py::init([](std::string dataset_file, std::string usage, py::handle sampler,
-                                     py::dict class_indexing, bool decode, std::shared_ptr<CacheClient> cc) {
+                                     py::dict class_indexing, bool decode) {
                       auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler),
-                                                                     toStringMap(class_indexing), decode,
-                                                                     toDatasetCache(std::move(cc)));
+                                                                     toStringMap(class_indexing), decode, nullptr);
                       THROW_IF_ERROR(manifest->ValidateParams());
                       return manifest;
                     }));
@@ -261,28 +260,25 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
 PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) {
                   (void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode",
                                                                                        "to create an MnistNode")
-                    .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler,
-                                     std::shared_ptr<CacheClient> cc) {
-                      auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler),
-                                                               toDatasetCache(std::move(cc)));
+                    .def(py::init([](std::string dataset_dir, std::string usage, py::handle sampler) {
+                      auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
                       THROW_IF_ERROR(mnist->ValidateParams());
                       return mnist;
                     }));
                 }));
-PYBIND_REGISTER(
-  RandomNode, 2, ([](const py::module *m) {
-    (void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode")
-      .def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list,
-                       std::shared_ptr<CacheClient> cc) {
-        auto random_node =
-          std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
-        THROW_IF_ERROR(random_node->ValidateParams());
-        return random_node;
-      }))
-      .def(py::init([](int32_t total_rows, std::string schema, py::list columns_list, std::shared_ptr<CacheClient> cc) {
-        auto random_node =
-          std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc)));
-        THROW_IF_ERROR(random_node->ValidateParams());
-        return random_node;
-      }));
+PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
+                  (void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode",
+                                                                                         "to create a RandomNode")
+                    .def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, py::list columns_list) {
+                      auto random_node =
+                        std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
+                      THROW_IF_ERROR(random_node->ValidateParams());
+                      return random_node;
+                    }))
+                    .def(py::init([](int32_t total_rows, std::string schema, py::list columns_list) {
+                      auto random_node =
+                        std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), nullptr);
+                      THROW_IF_ERROR(random_node->ValidateParams());
+                      return random_node;
+                    }));
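Note: the two RandomNode constructors now differ only in the schema type (SchemaObj vs. JSON string), and pybind11 dispatches on the Python argument. A hedged sketch from the public API, assuming the usual RandomDataset/Schema surface; the column name and shape here are made up:

    import mindspore.dataset as ds

    schema = ds.Schema()
    schema.add_column("data", de_type="int32", shape=[2])  # illustrative column
    rand_ds = ds.RandomDataset(total_rows=4, schema=schema, columns_list=["data"])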
@@ -292,10 +288,10 @@ PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) {
                   (void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode",
                                                                                              "to create a TextFileNode")
                     .def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards,
-                                     int32_t shard_id, std::shared_ptr<CacheClient> cc) {
-                      std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>(
-                        toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id,
-                        toDatasetCache(std::move(cc)));
+                                     int32_t shard_id) {
+                      std::shared_ptr<TextFileNode> textfile_node =
+                        std::make_shared<TextFileNode>(toStringVector(dataset_files), num_samples,
+                                                       toShuffleMode(shuffle), num_shards, shard_id, nullptr);
                       THROW_IF_ERROR(textfile_node->ValidateParams());
                       return textfile_node;
                     }));
@@ -306,19 +302,19 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
                                                                                              "to create a TFRecordNode")
                     .def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, py::list columns_list,
                                      int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
-                                     bool shard_equal_rows, std::shared_ptr<CacheClient> cc) {
+                                     bool shard_equal_rows) {
                       std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
                         toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
-                        toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
+                        toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
                       THROW_IF_ERROR(tfrecord->ValidateParams());
                       return tfrecord;
                     }))
                     .def(py::init([](py::list dataset_files, std::string schema, py::list columns_list,
                                      int64_t num_samples, int32_t shuffle, int32_t num_shards, int32_t shard_id,
-                                     bool shard_equal_rows, std::shared_ptr<CacheClient> cc) {
+                                     bool shard_equal_rows) {
                       std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>(
                         toStringVector(dataset_files), schema, toStringVector(columns_list), num_samples,
-                        toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, toDatasetCache(std::move(cc)));
+                        toShuffleMode(shuffle), num_shards, shard_id, shard_equal_rows, nullptr);
                       THROW_IF_ERROR(tfrecord->ValidateParams());
                       return tfrecord;
                     }));
@@ -326,12 +322,10 @@ PYBIND_REGISTER(TFRecordNode, 2, ([](const py::module *m) {
 PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
                   (void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
-                    .def(
-                      py::init([](std::string dataset_dir, std::string task, std::string usage, py::dict class_indexing,
-                                  bool decode, py::handle sampler, std::shared_ptr<CacheClient> cc) {
-                        std::shared_ptr<VOCNode> voc =
-                          std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
-                                                    toSamplerObj(sampler), toDatasetCache(std::move(cc)));
+                    .def(py::init([](std::string dataset_dir, std::string task, std::string usage,
+                                     py::dict class_indexing, bool decode, py::handle sampler) {
+                      std::shared_ptr<VOCNode> voc = std::make_shared<VOCNode>(
+                        dataset_dir, task, usage, toStringMap(class_indexing), decode, toSamplerObj(sampler), nullptr);
                       THROW_IF_ERROR(voc->ValidateParams());
                       return voc;
                     }));
@@ -439,11 +433,11 @@ PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) {
 PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
                   (void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
                     .def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
-                                     py::list output_columns, py::list project_columns, std::shared_ptr<CacheClient> cc,
+                                     py::list output_columns, py::list project_columns,
                                      std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) {
                       auto map = std::make_shared<MapNode>(
                         self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
-                        toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)),
+                        toStringVector(output_columns), toStringVector(project_columns), nullptr,
                         std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()));
                       THROW_IF_ERROR(map->ValidateParams());
                       return map;

@@ -212,6 +212,11 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
   return shared_from_this();
 }
 
+std::shared_ptr<DatasetNode> DatasetNode::SetDatasetCache(const std::shared_ptr<DatasetCache> &cache) {
+  cache_ = cache;
+  return shared_from_this();
+}
+
 DatasetNode::DatasetNode()
     : cache_(nullptr),
       parent_(nullptr),

@@ -260,6 +260,11 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
   /// \return Shared pointer to the original object
   std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);
 
+  /// \brief Setter function for DatasetCache
+  /// \param[in] cache Shared pointer to DatasetCache
+  /// \return Shared pointer to the original object
+  std::shared_ptr<DatasetNode> SetDatasetCache(const std::shared_ptr<DatasetCache> &cache);
+
   /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
   ///        Similar to shared_from_this, except this one will give you the derived class as shared_ptr
   /// \return A shared_ptr casted to the derived class
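Note: because SetDatasetCache, like SetNumWorkers, returns shared_from_this(), the bound setters compose fluently from Python. A minimal sketch, assuming `node` is an existing cde.Dataset:

    # Fluent chaining enabled by each setter returning shared_from_this().
    node = node.set_num_workers(4).set_cache_client()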

@@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde
 __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
            'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
-           'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers']
+           'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers', '_init_device_info']
 
 INT32_MAX = 2147483647
 UINT32_MAX = 4294967295

File diff suppressed because it is too large.

@@ -83,7 +83,7 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
             shuffle = True
             return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
         # If shuffle is not specified, sharding disabled, use random sampler
-        if num_samples is not None:
+        if num_samples is not None and num_samples != 0:
             return RandomSampler(replacement=True, num_samples=num_samples)
         return RandomSampler(num_samples=num_samples)
     if shuffle is True:
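Note: the tightened guard changes behaviour when num_samples=0 reaches this branch. It previously produced RandomSampler(replacement=True, num_samples=0); it now falls through to the plain RandomSampler. A standalone sketch of just the branch logic (not MindSpore code):

    def choose(num_samples):
        # Mirrors the patched branch: 0 is now treated like None
        # for the replacement=True path.
        if num_samples is not None and num_samples != 0:
            return {"replacement": True, "num_samples": num_samples}
        return {"num_samples": num_samples}

    assert choose(10) == {"replacement": True, "num_samples": 10}
    assert choose(0) == {"num_samples": 0}  # pre-patch: replacement=True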

@@ -123,6 +123,35 @@ class Compose:
         """
         return util.compose(self.transforms, *args)
 
+    @staticmethod
+    def reduce(operations):
+        """
+        Wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations
+
+        Args:
+            operations (list): list of tensor operations
+
+        Returns:
+            list, the reduced list of operations
+        """
+        #
+        if len(operations) == 1:
+            return operations
+        new_ops, start_ind, end_ind = [], 0, 0
+        for i, op in enumerate(operations):
+            if str(op).find("c_transform") >= 0:
+                # reset counts
+                if start_ind != end_ind:
+                    new_ops.append(Compose(operations[start_ind:end_ind]))
+                new_ops.append(op)
+                start_ind, end_ind = i + 1, i + 1
+            else:
+                end_ind += 1
+        # do additional check in case the last operation is a Python operation
+        if start_ind != end_ind:
+            new_ops.append(Compose(operations[start_ind:end_ind]))
+        return new_ops
+
 
 class RandomApply:
     """
