|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "dataset/core/config_manager.h"
|
|
|
|
|
#include "dataset/engine/data_buffer.h"
|
|
|
|
|
#include "dataset/engine/datasetops/skip_op.h"
|
|
|
|
|
#include "dataset/engine/db_connector.h"
|
|
|
|
@ -26,7 +27,10 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
|
// Builder constructor. Creates the builder object.
|
|
|
|
|
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
|
|
|
|
|
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
|
|
|
|
|
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
|
|
|
|
builder_op_connector_size_ = cfg->op_connector_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status SkipOp::Builder::SanityCheck() const {
|
|
|
|
|
if (build_max_skips_ < 0) {
|
|
|
|
@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
|
|
|
|
|
// The builder "build" method creates the final object.
|
|
|
|
|
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
|
|
|
|
|
RETURN_IF_NOT_OK(SanityCheck());
|
|
|
|
|
*ptr = std::make_shared<SkipOp>(build_max_skips_);
|
|
|
|
|
*ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Constructor of the SkipOp.
|
|
|
|
|
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
|
|
|
|
|
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
|
|
|
|
|
: PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
|
|
|
|
|
|
|
|
|
|
// Destructor
|
|
|
|
|
SkipOp::~SkipOp() {}
|
|
|
|
@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
|
|
|
|
|
<< "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Since the buffer may contain multi rows, this function will drop the rows
|
|
|
|
|
// that need to skip in it, and then return the buffer.
|
|
|
|
|
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
|
|
|
|
|
if (child_.empty()) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<DataBuffer> buf;
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
|
|
|
|
|
|
|
|
|
// Drop first max_skips_ rows
|
|
|
|
|
while (skip_count_ < max_skips_) {
|
|
|
|
|
if (buf->eoe() || buf->eof()) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Consider the rows of buffer more than 1
|
|
|
|
|
TensorRow drop_row;
|
|
|
|
|
int row_num = buf->NumRows();
|
|
|
|
|
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
|
|
|
|
|
skip_count_ += drop_num;
|
|
|
|
|
for (int i = 0; i < drop_num; i++) {
|
|
|
|
|
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
|
|
|
|
|
}
|
|
|
|
|
if (buf->NumRows() == 0) {
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Handling eoe
|
|
|
|
|
if (buf->eoe()) {
|
|
|
|
|
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Handling eof
|
|
|
|
|
if (buf->eof()) {
|
|
|
|
|
RETURN_IF_NOT_OK(EofReceived(worker_id));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*p_buffer = std::move(buf);
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Base-class override for handling cases when an eoe is received.
|
|
|
|
|
Status SkipOp::EoeReceived(int32_t worker_id) {
|
|
|
|
|
skip_count_ = 0;
|
|
|
|
@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Class functor operator () override.
|
|
|
|
|
// Most dataset ops operate by launching a thread (see ExecutionTree).
|
|
|
|
|
// However, the SkipOp is defined as a inlined operator, so it is invalid to
|
|
|
|
|
// launch the functor since this op runs inlined inside another operator. The
|
|
|
|
|
// function is overloaded to ensure that it is not called by mistake (it will
|
|
|
|
|
// generate an error).
|
|
|
|
|
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
|
|
|
|
|
// main entry point for skip
|
|
|
|
|
Status SkipOp::operator()() {
|
|
|
|
|
TaskManager::FindMe()->Post();
|
|
|
|
|
std::unique_ptr<DataBuffer> curr_buffer;
|
|
|
|
|
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
|
|
|
|
while (curr_buffer->eof() == false) {
|
|
|
|
|
// Reset count
|
|
|
|
|
skip_count_ = 0;
|
|
|
|
|
while (curr_buffer->eoe() == false) {
|
|
|
|
|
// Drop first count rows
|
|
|
|
|
while (skip_count_ < max_skips_) {
|
|
|
|
|
if (curr_buffer->eoe() || curr_buffer->eof()) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
// Consider the rows of buffer more than one
|
|
|
|
|
TensorRow drop_row;
|
|
|
|
|
int row_num = curr_buffer->NumRows();
|
|
|
|
|
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
|
|
|
|
|
skip_count_ += drop_num;
|
|
|
|
|
for (int i = 0; i < drop_num; i++) {
|
|
|
|
|
RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
|
|
|
|
|
}
|
|
|
|
|
if (curr_buffer->NumRows() == 0) {
|
|
|
|
|
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
|
|
|
|
|
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
|
|
|
|
}
|
|
|
|
|
// we got eoe, now try again until we got eof
|
|
|
|
|
MS_LOG(DEBUG) << "Skip operator EOE Received.";
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
|
|
|
|
|
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Skip operator EOF Received.";
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Base-class override for handling cases when an eof is received.
|
|
|
|
|
Status SkipOp::EofReceived(int32_t worker_id) {
|
|
|
|
|