!871 dataset: repair take op problem when the next node is multi-threaded

Merge pull request !871 from ms_yan/take_operator
pull/871/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 67f3d0eb5d

@ -17,6 +17,7 @@
#include <utility>
#include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
@ -25,7 +26,10 @@
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
// Builder constructor. Creates the builder object.
// The connector size handed to the TakeOp is read from the global ConfigManager
// at construction time.
// @param count - The number of takes to do
TakeOp::Builder::Builder(int32_t count)
    : build_max_takes_(count),
      builder_op_connector_size_(GlobalContext::config_manager()->op_connector_size()) {}
Status TakeOp::Builder::SanityCheck() const {
if (build_max_takes_ <= 0) {
@ -38,12 +42,13 @@ Status TakeOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<TakeOp>(build_max_takes_);
*ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
return Status::OK();
}
// Constructor of the TakeOp.
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
// Constructor of the TakeOp.
// @param count - The number of takes to do; stored in max_takes_
// @param op_connector_size - Forwarded unchanged to the PipelineOp base constructor
// The running counter take_count_ always starts at 0.
TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
: PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
// A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const {
@ -62,59 +67,41 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
}
}
// This function will be called multiple times to return buffers. It stops once the required
// max take count is reached or an EOF buffer is received.
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
}
// Main entry point for Take
Status TakeOp::operator()() {
TaskManager::FindMe()->Post();
std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
if (take_count_ == max_takes_) {
if (state_ == OpState::kDeOpRunning) {
MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer.";
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
*p_buffer = std::move(eoe_buffer);
state_ = OpState::kDeOpIdle;
// Reset the count and drain
if (!last_repeat) {
take_count_ = 0;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
while (buf->eof() == false) {
if (take_count_ == max_takes_) {
// Do drain Operation
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
}
} else if (state_ == OpState::kDeOpIdle) {
MS_LOG(DEBUG) << "Meet max count and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
*p_buffer = std::move(eof_buffer);
}
// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
} else {
MS_LOG(WARNING) << "Invalid OpState: " << state_;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
continue;
}
return Status::OK();
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Loop until non EOE is received
if (buf->eoe()) {
take_count_ = 0;
*p_buffer = std::move(buf);
return Status::OK();
}
// Check if the last buf is next eof
if (buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
// Get buffer and push back when take_count is still small
if (take_count_ < max_takes_) {
std::unique_ptr<DataBuffer> p_buffer;
RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer)));
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
}
// Get buffer and push back when take_count is still small
if (take_count_ < max_takes_) {
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
}
take_count_ = 0;
MS_LOG(DEBUG) << "Meet the end and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
return Status::OK();
}
@ -139,13 +126,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
return Status::OK();
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
Status TakeOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToRepeatStack(shared_from_this());

@ -45,6 +45,7 @@ class TakeOp : public PipelineOp {
private:
int32_t build_max_takes_;
int32_t builder_op_connector_size_;
Status SanityCheck() const;
};
@ -52,7 +53,7 @@ class TakeOp : public PipelineOp {
// Constructor of the TakeOp.
// @note The builder class should be used to call it
// @param count - The number of takes to do
explicit TakeOp(int32_t count);
explicit TakeOp(int32_t count, int32_t op_connector_size);
// Destructor
~TakeOp() = default;
@ -72,23 +73,11 @@ class TakeOp : public PipelineOp {
return out;
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
Status operator()() override;
// Gets a buffer from the child node. The caller is typically our parent node.
// @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
// this function will retry to pop the connector again and will get the non-EOE buffer if any.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @note Derived versions of this function should always call its superclass version first

@ -30,6 +30,12 @@ def generator_10():
yield np.array([i]),
def filter_func_ge(data):
    """
    Filter predicate: reject any element greater than 3, keep everything else.

    Returns:
        bool, False when data > 3, True otherwise.
    """
    return not data > 3
def test_take_01():
"""
Test take: origin there are 3 row, and take 1 row, in this case: will not meet eoe and eof
@ -297,6 +303,44 @@ def test_take_16():
assert sum([1 for _ in data1]) == 5
def test_take_17():
    """
    Test take: take first, then apply a filter operation with parallel workers
    """
    logger.info("test_take_17")
    dataset = ds.GeneratorDataset(generator_10, ["data"])
    dataset = dataset.take(8).filter(predicate=filter_func_ge, num_parallel_workers=4)

    # Every row that passes the filter must hold its own position as the data element
    for index, row in enumerate(dataset):
        assert index == row[0][0]

    # filter_func_ge rejects values greater than 3, so 4 rows are expected
    assert sum(1 for _ in dataset) == 4
def test_take_18():
    """
    Test take: take first, then apply filter, skip, batch and repeat operations
    """
    logger.info("test_take_18")
    dataset = ds.GeneratorDataset(generator_10, ["data"])
    dataset = dataset.take(8)
    dataset = dataset.filter(predicate=filter_func_ge, num_parallel_workers=4)
    dataset = dataset.skip(2).batch(2).repeat(2)

    # Each batch is expected to start with the element 2
    for batch in dataset:
        assert 2 == batch[0][0]

    # One batch per repeat, two repeats
    assert sum(1 for _ in dataset) == 2
if __name__ == '__main__':
test_take_01()
test_take_02()
@ -314,4 +358,6 @@ if __name__ == '__main__':
test_take_14()
test_take_15()
test_take_16()
test_take_17()
test_take_18()
logger.info('== test take operation finished ==')
Loading…
Cancel
Save