diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
index fbd6ec87b7..55b4709d6b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
@@ -193,9 +193,9 @@ Status GeneratorOp::operator()() {
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
   std::unique_ptr<DataBuffer> fetched_buffer;
+  int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
   RETURN_IF_NOT_OK(Init());
-  int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
   bool eof = false;
   while (!eof) {
     // Create new buffer each iteration
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
index c8f478c818..ddd043097d 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
@@ -184,7 +184,6 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
     if (device_id_ < remainder) shard_size++;
     if (device_id_ < offset_) shard_size--;
   } else {
-    offset_ = 0;
     shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
   }
   // add 1 to an empty shard
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
index 4ab4350e2a..3b839a85aa 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.cc
@@ -42,7 +42,12 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
 
 GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
                              int64_t source_len, std::shared_ptr<SamplerObj> sampler)
-    : MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {}
+    : MappableSourceNode(),
+      generator_function_(generator_function),
+      schema_(schema),
+      reset_ancestor_(nullptr),
+      sampler_(std::move(sampler)),
+      source_len_(source_len) {}
 
 std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
   std::shared_ptr<GeneratorNode> node;
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 3c2986b0bc..72dfd50ba1 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -2233,6 +2233,7 @@ _GLOBAL_PYFUNC_LIST = []
 _OP_NAME = dict()
 _OP_PROCESS = dict()
 
+
 # Pyfunc worker init function
 # Python multiprocessing library forbid sending lambda function through pipe.
 # This init function allow us to add all Python function to a global collection and then fork afterwards.
@@ -3781,6 +3782,8 @@ class GeneratorDataset(MappableDataset):
         try:
             new_op.sampler = None
             new_op.sample_fn = sample_fn
+            new_op.source_len = min(new_op.source_len,
+                                    new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
             iter(self.source)
         except TypeError:
             # Use generator function if input callable
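
The following is a minimal Python sketch, not part of the patch, that mirrors the arithmetic visible in the diff: the ceiling-division shard size computed in DistributedSamplerRT::CalculateNumSamples, and the capping of the generator's source length by num_samples added in datasets.py. The function names and standalone structure here are illustrative only.

```python
# Illustrative sketch only -- mirrors the intent of the patched sampling logic,
# not the actual MindSpore implementation.

def shard_size(child_num_rows, num_devices):
    """Rows per device shard via ceiling division, as in CalculateNumSamples."""
    return (child_num_rows + num_devices - 1) // num_devices

def effective_source_len(source_len, num_samples):
    """Cap the reported source length by num_samples when one is given,
    matching the new GeneratorDataset behavior in datasets.py."""
    return min(source_len, num_samples) if num_samples is not None else source_len

if __name__ == "__main__":
    assert shard_size(10, 4) == 3            # 10 rows over 4 devices -> 3 rows each
    assert effective_source_len(100, 20) == 20
    assert effective_source_len(100, None) == 100
```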