|
|
|
@ -17,6 +17,7 @@
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "common/utils.h"
|
|
|
|
|
#include "dataset/core/config_manager.h"
|
|
|
|
|
#include "dataset/engine/data_buffer.h"
|
|
|
|
|
#include "dataset/engine/datasetops/take_op.h"
|
|
|
|
|
#include "dataset/engine/db_connector.h"
|
|
|
|
@ -25,7 +26,10 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
|
// Builder constructor. Creates the builder object.
|
|
|
|
|
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
|
|
|
|
|
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {
|
|
|
|
|
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
|
|
|
|
builder_op_connector_size_ = cfg->op_connector_size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status TakeOp::Builder::SanityCheck() const {
|
|
|
|
|
if (build_max_takes_ <= 0) {
|
|
|
|
@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const {
|
|
|
|
|
// The builder "build" method creates the final object.
|
|
|
|
|
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
|
|
|
|
|
RETURN_IF_NOT_OK(SanityCheck());
|
|
|
|
|
*ptr = std::make_shared<TakeOp>(build_max_takes_);
|
|
|
|
|
*ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Constructor of the TakeOp.
|
|
|
|
|
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
|
|
|
|
|
TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
|
|
|
|
|
: PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
|
|
|
|
|
|
|
|
|
|
// A print method typically used for debugging
|
|
|
|
|
void TakeOp::Print(std::ostream &out, bool show_all) const {
|
|
|
|
@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This function will be call muti times to returns the buffer, when meet required max take count or meet
|
|
|
|
|
// EOF buffer then this will stop.
|
|
|
|
|
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
|
|
|
|
|
if (child_.empty()) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Main entry point for Take
|
|
|
|
|
Status TakeOp::operator()() {
|
|
|
|
|
TaskManager::FindMe()->Post();
|
|
|
|
|
std::unique_ptr<DataBuffer> buf;
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
|
|
|
|
|
|
|
|
|
|
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
|
|
|
|
|
if (take_count_ == max_takes_) {
|
|
|
|
|
if (state_ == OpState::kDeOpRunning) {
|
|
|
|
|
MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer.";
|
|
|
|
|
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
|
|
|
|
*p_buffer = std::move(eoe_buffer);
|
|
|
|
|
state_ = OpState::kDeOpIdle;
|
|
|
|
|
|
|
|
|
|
// Reset the count and drain
|
|
|
|
|
if (!last_repeat) {
|
|
|
|
|
take_count_ = 0;
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
|
|
|
|
while (!buf->eoe() && !buf->eof()) {
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
|
|
|
|
}
|
|
|
|
|
while (buf->eof() == false) {
|
|
|
|
|
if (take_count_ == max_takes_) {
|
|
|
|
|
// Do drain Operation
|
|
|
|
|
while (!buf->eoe() && !buf->eof()) {
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
|
|
|
|
|
}
|
|
|
|
|
} else if (state_ == OpState::kDeOpIdle) {
|
|
|
|
|
MS_LOG(DEBUG) << "Meet max count and push-back eof buffer.";
|
|
|
|
|
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
|
|
|
|
*p_buffer = std::move(eof_buffer);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Loop until non EOE is received
|
|
|
|
|
if (buf->eoe()) {
|
|
|
|
|
take_count_ = 0;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(WARNING) << "Invalid OpState: " << state_;
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
|
|
|
|
// Loop until non EOE is received
|
|
|
|
|
if (buf->eoe()) {
|
|
|
|
|
take_count_ = 0;
|
|
|
|
|
*p_buffer = std::move(buf);
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check if the last buf is next eof
|
|
|
|
|
if (buf->eof()) {
|
|
|
|
|
*p_buffer = std::move(buf);
|
|
|
|
|
return Status::OK();
|
|
|
|
|
// Get buffer and push back when take_count is still small
|
|
|
|
|
if (take_count_ < max_takes_) {
|
|
|
|
|
std::unique_ptr<DataBuffer> p_buffer;
|
|
|
|
|
RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer));
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer)));
|
|
|
|
|
}
|
|
|
|
|
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get buffer and push back when take_count is still small
|
|
|
|
|
if (take_count_ < max_takes_) {
|
|
|
|
|
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
|
|
|
|
|
}
|
|
|
|
|
take_count_ = 0;
|
|
|
|
|
MS_LOG(DEBUG) << "Meet the end and push-back eof buffer.";
|
|
|
|
|
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Class functor operator () override.
|
|
|
|
|
// Most dataset ops operate by launching a thread (see ExecutionTree).
|
|
|
|
|
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
|
|
|
|
|
// functor since this op runs inlined inside another operator. The function is overloaded to
|
|
|
|
|
// ensure that it is not called by mistake (it will generate an error).
|
|
|
|
|
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
|
|
|
|
|
|
|
|
|
|
Status TakeOp::PrepareNodePostAction() {
|
|
|
|
|
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
|
|
|
|
|
tree_->AddToRepeatStack(shared_from_this());
|
|
|
|
|