diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc index 872c4c27c5..7e6055027e 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc @@ -17,6 +17,7 @@ #include #include "common/utils.h" +#include "dataset/core/config_manager.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/datasetops/take_op.h" #include "dataset/engine/db_connector.h" @@ -25,7 +26,10 @@ namespace mindspore { namespace dataset { // Builder constructor. Creates the builder object. -TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {} +TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} Status TakeOp::Builder::SanityCheck() const { if (build_max_takes_ <= 0) { @@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const { // The builder "build" method creates the final object. Status TakeOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_takes_); + *ptr = std::make_shared(build_max_takes_, builder_op_connector_size_); return Status::OK(); } // Constructor of the TakeOp. -TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {} +TakeOp::TakeOp(int32_t count, int32_t op_connector_size) + : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {} // A print method typically used for debugging void TakeOp::Print(std::ostream &out, bool show_all) const { @@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const { } } -// This function will be call muti times to returns the buffer, when meet required max take count or meet -// EOF buffer then this will stop. -Status TakeOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { - if (child_.empty()) { - RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node."); - } - +// Main entry point for Take +Status TakeOp::operator()() { + TaskManager::FindMe()->Post(); std::unique_ptr buf; + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat); - if (take_count_ == max_takes_) { - if (state_ == OpState::kDeOpRunning) { - MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer."; - auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - *p_buffer = std::move(eoe_buffer); - state_ = OpState::kDeOpIdle; - - // Reset the count and drain - if (!last_repeat) { - take_count_ = 0; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - while (!buf->eoe() && !buf->eof()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - } + while (buf->eof() == false) { + if (take_count_ == max_takes_) { + // Do drain Operation + while (!buf->eoe() && !buf->eof()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); } - } else if (state_ == OpState::kDeOpIdle) { - MS_LOG(DEBUG) << "Meet max count and push-back eof buffer."; - auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - *p_buffer = std::move(eof_buffer); + } + + // Loop until non EOE is received + if (buf->eoe()) { take_count_ = 0; - } else { - MS_LOG(WARNING) << "Invalid OpState: " << state_; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); + continue; } - return Status::OK(); - } - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - // Loop until non EOE is received - if (buf->eoe()) { - take_count_ = 0; - *p_buffer = std::move(buf); - return Status::OK(); - } - // Check if the last buf is next eof - if (buf->eof()) { - *p_buffer = std::move(buf); - return Status::OK(); + // Get buffer and push back when take_count is still small + if (take_count_ < max_takes_) { + std::unique_ptr p_buffer; + RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer))); + } + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); } - // Get buffer and push back when take_count is still small - if (take_count_ < max_takes_) { - RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer)); - } + take_count_ = 0; + MS_LOG(DEBUG) << "Meet the end and push-back eof buffer."; + auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); return Status::OK(); } @@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr *buffer, std::unique_ptrAddToRepeatStack(shared_from_this()); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h index 02218cf610..f70a1e91a3 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h @@ -45,6 +45,7 @@ class TakeOp : public PipelineOp { private: int32_t build_max_takes_; + int32_t builder_op_connector_size_; Status SanityCheck() const; }; @@ -52,7 +53,7 @@ class TakeOp : public PipelineOp { // Constructor of the TakeOp. // @note The builder class should be used to call it // @param count - The number of takes to do - explicit TakeOp(int32_t count); + explicit TakeOp(int32_t count, int32_t op_connector_size); // Destructor ~TakeOp() = default; @@ -72,23 +73,11 @@ class TakeOp : public PipelineOp { return out; } - // Class functor operator () override. - // Most dataset ops operate by launching a thread (see ExecutionTree). - // However, the TakeOp is defined as a inlined operator, so it is invalid to launch the - // functor since this op runs inlined inside another operator. The function is overloaded to - // ensure that it is not called by mistake (it will generate an error). + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work // @return Status - The error code return Status operator()() override; - // Gets a buffer from the child node. The caller is typically our parent node. - // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way, - // this function will retry to pop the connector again and will get the non-EOE buffer if any. - // @param p_buffer - output pointer to the buffer that it will fetch. - // @param worker_id - The worker id - // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. - // @return Status - The error code return - Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; - // During tree prepare phase, operators may have specific post-operations to perform depending on // their role. // @notes Derived versions of this function should always call it's superclass version first diff --git a/tests/ut/python/dataset/test_take.py b/tests/ut/python/dataset/test_take.py index ed71f67e26..64efc7a785 100644 --- a/tests/ut/python/dataset/test_take.py +++ b/tests/ut/python/dataset/test_take.py @@ -30,6 +30,12 @@ def generator_10(): yield np.array([i]), +def filter_func_ge(data): + if data > 3: + return False + return True + + def test_take_01(): """ Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof @@ -297,6 +303,44 @@ def test_take_16(): assert sum([1 for _ in data1]) == 5 +def test_take_17(): + """ + Test take: take first, then do fiter operation + """ + logger.info("test_take_17") + data1 = ds.GeneratorDataset(generator_10, ["data"]) + + data1 = data1.take(8) + data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4) + + # Here i refers to index, d refers to data element + for i, d in enumerate(data1): + assert i == d[0][0] + + assert sum([1 for _ in data1]) == 4 + + +def test_take_18(): + """ + Test take: take first, then do fiter, skip, batch and repeat operation + """ + logger.info("test_take_18") + data1 = ds.GeneratorDataset(generator_10, ["data"]) + + data1 = data1.take(8) + data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4) + data1 = data1.skip(2) + + data1 = data1.batch(2) + data1 = data1.repeat(2) + + # Here i refers to index, d refers to data element + for i, d in enumerate(data1): + assert 2 == d[0][0] + + assert sum([1 for _ in data1]) == 2 + + if __name__ == '__main__': test_take_01() test_take_02() @@ -314,4 +358,6 @@ if __name__ == '__main__': test_take_14() test_take_15() test_take_16() + test_take_17() + test_take_18() logger.info('== test take operation finished ==') \ No newline at end of file