datasets.py cleanup

pull/11971/head
hesham 4 years ago
parent 09db51a797
commit 4bbc3445f1

@@ -212,6 +212,11 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
  return shared_from_this();
}

std::shared_ptr<DatasetNode> DatasetNode::SetDatasetCache(const std::shared_ptr<DatasetCache> &cache) {
  cache_ = cache;
  return shared_from_this();
}

DatasetNode::DatasetNode()
    : cache_(nullptr),
      parent_(nullptr),

@@ -260,6 +260,11 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
  /// \return Shared pointer to the original object
  std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);

  /// \brief Setter function for DatasetCache
  /// \param[in] cache Shared pointer to DatasetCache
  /// \return Shared pointer to the original object
  std::shared_ptr<DatasetNode> SetDatasetCache(const std::shared_ptr<DatasetCache> &cache);

  /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
  /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr
  /// \return A shared_ptr casted to the derived class
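
On the Python side, the cache stored by this setter comes from the cache argument of the dataset APIs. A minimal usage sketch, assuming the usual DatasetCache and Cifar10Dataset signatures; the session id and data path are placeholders:

import mindspore.dataset as ds

# Placeholder values: the session id would normally come from cache_admin -g,
# and the dataset directory must exist on the machine running the pipeline.
some_cache = ds.DatasetCache(session_id=1, size=0)
data = ds.Cifar10Dataset("/path/to/cifar-10-batches-bin", num_samples=4, cache=some_cache)
# The cache argument is what ends up on the dataset IR node through SetDatasetCache.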

@@ -23,7 +23,7 @@ import mindspore._c_dataengine as cde
 __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
            'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load',
-           'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers']
+           'get_callback_timeout', 'set_auto_num_workers', 'get_auto_num_workers', '_init_device_info']

 INT32_MAX = 2147483647
 UINT32_MAX = 4294967295
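
Adding _init_device_info to __all__ makes the helper part of the config module's public, star-importable surface. A small check, assuming the module path mindspore.dataset.core.config:

from mindspore.dataset.core import config

# After this change the helper is advertised alongside set_seed, set_prefetch_size, etc.
print("_init_device_info" in config.__all__)  # expected: True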

File diff suppressed because it is too large

@@ -83,7 +83,7 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
             shuffle = True
             return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
         # If shuffle is not specified, sharding disabled, use random sampler
-        if num_samples is not None:
+        if num_samples is not None and num_samples != 0:
             return RandomSampler(replacement=True, num_samples=num_samples)
         return RandomSampler(num_samples=num_samples)
     if shuffle is True:
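
The only behavioural change above is the extra num_samples != 0 guard: an explicit request for zero samples now falls through to the plain RandomSampler instead of the one that draws with replacement. A standalone sketch of that branch (names here are illustrative, not the library code):

def pick_random_sampler(num_samples):
    # Mirrors the changed branch: only an explicit, nonzero num_samples selects
    # sampling with replacement.
    if num_samples is not None and num_samples != 0:
        return ("RandomSampler", {"replacement": True, "num_samples": num_samples})
    return ("RandomSampler", {"num_samples": num_samples})

assert pick_random_sampler(10) == ("RandomSampler", {"replacement": True, "num_samples": 10})
assert pick_random_sampler(0) == ("RandomSampler", {"num_samples": 0})      # new behaviour
assert pick_random_sampler(None) == ("RandomSampler", {"num_samples": None})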

@@ -123,6 +123,35 @@ class Compose:
        """
        return util.compose(self.transforms, *args)

    @staticmethod
    def reduce(operations):
        """
        Wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations.

        Args:
            operations (list): list of tensor operations.

        Returns:
            list, the reduced list of operations.
        """
        # a single operation needs no grouping
        if len(operations) == 1:
            return operations
        new_ops, start_ind, end_ind = [], 0, 0
        for i, op in enumerate(operations):
            if str(op).find("c_transform") >= 0:
                # C++ operation found: flush pending Python operations into a Compose, then reset counts
                if start_ind != end_ind:
                    new_ops.append(Compose(operations[start_ind:end_ind]))
                new_ops.append(op)
                start_ind, end_ind = i + 1, i + 1
            else:
                end_ind += 1
        # do additional check in case the last operation is a Python operation
        if start_ind != end_ind:
            new_ops.append(Compose(operations[start_ind:end_ind]))
        return new_ops


class RandomApply:
    """
