fix generator or user defined sampler len method unmatch iter method

pull/11300/head
anzhengqi 4 years ago
parent 5e5489d59f
commit 50b783ee13

@ -49,14 +49,16 @@ Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) {
GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size,
int32_t connector_size)
int32_t connector_size, int64_t pre_counter_size)
: PipelineOp(connector_size),
generator_function_(generator_function),
column_names_(column_names),
column_types_(column_types),
prefetch_size_(prefetch_size),
buffer_size_(buffer_size),
buffer_id_(0) {}
pre_counter_size_(pre_counter_size),
buffer_id_(0),
generator_counter_(0) {}
GeneratorOp::~GeneratorOp() { this->Dealloc(); }
@ -146,6 +148,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) {
TensorRow row;
RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row));
tt->push_back(std::move(row));
generator_counter_++;
}
return Status::OK();
}
@ -209,6 +212,13 @@ Status GeneratorOp::operator()() {
if (!eoe) {
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what());
}
if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) {
std::stringstream ss;
ss << "The actual amount of data read from generator " << generator_counter_
<< " is different from generator.len " << pre_counter_size_
<< ", you should adjust generator.len to make them match.";
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str());
}
}
}
if (fetched_table->size() > 0) {
@ -254,6 +264,7 @@ Status GeneratorOp::Reset() {
// Wake up master thread
wp_.Set();
}
generator_counter_ = 0;
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
}

@ -93,7 +93,8 @@ class GeneratorOp : public PipelineOp {
};
GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size);
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size,
int64_t pre_counter_size = 0);
~GeneratorOp();
@ -142,6 +143,8 @@ class GeneratorOp : public PipelineOp {
std::vector<DataType> column_types_;
int32_t prefetch_size_;
int32_t buffer_size_;
int64_t pre_counter_size_;
int64_t generator_counter_;
py::object generator_;
int32_t buffer_id_;

@ -46,6 +46,7 @@ std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
} else {
node = std::make_shared<GeneratorNode>(generator_function_, schema_);
}
node->SetGeneratorDatasetSize(dataset_size_);
return node;
}
@ -72,7 +73,7 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
// GeneratorOp internally. Here it is given a zero which is the default in generator builder
std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
rows_per_buffer_, connector_que_size_);
rows_per_buffer_, connector_que_size_, dataset_size_);
// Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init
// needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private).

@ -2663,7 +2663,7 @@ class ConcatDataset(Dataset):
tem_sampler = copy.deepcopy(sampler)
tem_sampler.set_offset(cumulative_samples_nums)
child.sampler = tem_sampler
child.use_sampler(tem_sampler)
cumulative_samples_nums += self.children_sizes_[index]
cumulative_samples_nums %= sampler.num_shards
@ -3808,6 +3808,8 @@ class GeneratorDataset(MappableDataset):
self.dataset_size = math.ceil(len(self.source) / self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if self.num_samples is not None and self.num_samples < rows_from_sampler:
rows_from_sampler = self.num_samples
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler

Loading…
Cancel
Save