!12154 Fix samplers bugs

From: @hfarahat
Reviewed-by: @heleiwang,@liucunwei
Signed-off-by: @liucunwei
pull/12154/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit bf528c6817

@ -193,9 +193,9 @@ Status GeneratorOp::operator()() {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
std::unique_ptr<DataBuffer> fetched_buffer;
int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
RETURN_IF_NOT_OK(Init());
int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
bool eof = false;
while (!eof) {
// Create new buffer each iteration

@ -184,7 +184,6 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
if (device_id_ < remainder) shard_size++;
if (device_id_ < offset_) shard_size--;
} else {
offset_ = 0;
shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
}
// add 1 to an empty shard

@ -42,7 +42,12 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
int64_t source_len, std::shared_ptr<SamplerObj> sampler)
: MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {}
: MappableSourceNode(),
generator_function_(generator_function),
schema_(schema),
reset_ancestor_(nullptr),
sampler_(std::move(sampler)),
source_len_(source_len) {}
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
std::shared_ptr<GeneratorNode> node;

@ -2233,6 +2233,7 @@ _GLOBAL_PYFUNC_LIST = []
_OP_NAME = dict()
_OP_PROCESS = dict()
# Pyfunc worker init function
# Python multiprocessing library forbid sending lambda function through pipe.
# This init function allow us to add all Python function to a global collection and then fork afterwards.
@ -3781,6 +3782,8 @@ class GeneratorDataset(MappableDataset):
try:
new_op.sampler = None
new_op.sample_fn = sample_fn
new_op.source_len = min(new_op.source_len,
new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
iter(self.source)
except TypeError:
# Use generator function if input callable

Loading…
Cancel
Save