@ -17,6 +17,7 @@
#include <utility>
#include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
@ -25,7 +26,10 @@
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
// NOTE(review): two definitions with the same signature appear back to back here, and the second has
// no closing brace in this view - this looks like diff residue; confirm against the real file.
// First form: only records count in build_max_takes_.
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
// Second form: records count in build_max_takes_, then fetches the ConfigManager from GlobalContext
// and stores the value returned by op_connector_size() in builder_op_connector_size_.
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size();
Status TakeOp::Builder::SanityCheck() const {
if (build_max_takes_ <= 0) {
@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
// Stores a newly created shared TakeOp in *ptr and returns Status::OK().
// NOTE(review): both make_shared assignments below are present in this view and the closing brace is
// missing - presumably only one of them belongs in the real file; confirm.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
// One-argument construction: only the max take count is passed.
*ptr = std::make_shared<TakeOp>(build_max_takes_);
// Two-argument construction: overwrites *ptr, additionally passing builder_op_connector_size_.
*ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
return Status::OK();
// Constructor of the TakeOp.
// One-argument form: initializes PipelineOp with 0, sets max_takes_ to count and take_count_ to 0.
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
// Two-argument form: same as above, except that op_connector_size is passed to the PipelineOp
// initializer instead of 0.
TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
: PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
// A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const {
@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
// This function is called multiple times to return buffers; it stops once the required max take count
// is reached or an EOF buffer is met.
// NOTE(review): the lines below interleave two definitions (GetNextBuffer, which writes to *p_buffer,
// and operator(), which pushes to out_connector_) and most closing braces are missing in this view.
// This looks like a diff with its +/- markers stripped; confirm against the real file before relying
// on any control-flow reading of this span.
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
// A TakeOp with no child cannot produce data, so this is reported as an unexpected status.
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
// Main entry point for Take
Status TakeOp::operator()() {
std::unique_ptr<DataBuffer> buf;
// last_repeat is true when kDeOpRepeated is not set in op_ctrl_flags_, or when kDeOpLastRepeat is set.
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
if (take_count_ == max_takes_) {
if (state_ == OpState::kDeOpRunning) {
MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer.";
// Hand an EOE-flagged buffer to the caller and switch state_ to idle.
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
*p_buffer = std::move(eoe_buffer);
state_ = OpState::kDeOpIdle;
// Reset the count and drain
if (!last_repeat) {
take_count_ = 0;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Keep fetching from the first child until an EOE or EOF buffer arrives.
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
while (buf->eof() == false) {
if (take_count_ == max_takes_) {
// Do drain operation
while (!buf->eoe() && !buf->eof()) {
} else if (state_ == OpState::kDeOpIdle) {
MS_LOG(DEBUG) << "Meet max count and push-back eof buffer.";
// Hand an EOF-flagged buffer to the caller.
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
*p_buffer = std::move(eof_buffer);
// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
} else {
MS_LOG(WARNING) << "Invalid OpState: " << state_;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
return Status::OK();
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
*p_buffer = std::move(buf);
return Status::OK();
// Check whether the buffer just fetched is EOF
if (buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
// Get buffer and push back when take_count is still small
if (take_count_ < max_takes_) {
// This local p_buffer receives the output of FillBuffer and is then added to the output connector.
std::unique_ptr<DataBuffer> p_buffer;
RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer)));
// Get buffer and push back when take_count is still small
if (take_count_ < max_takes_) {
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
take_count_ = 0;
MS_LOG(DEBUG) << "Meet the end and push-back eof buffer.";
// Push an EOF-flagged buffer to the output connector before returning.
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
return Status::OK();
@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
return Status::OK();
// Functor entry point override.
// Dataset ops are normally run by launching a thread on this functor (see ExecutionTree).
// TakeOp is defined as an inlined operator that runs inside another operator, so launching it
// is invalid. This override exists so that a mistaken call produces an error.
Status TakeOp::operator()() {
  RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator.");
}
Status TakeOp::PrepareNodePostAction() {