!856 Fix skip op bug

Merge pull request !856 from jiangzhiwen/dataset/skip_thread
pull/856/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 98939d839c

@ -16,6 +16,7 @@
#include <iostream> #include <iostream>
#include <utility> #include <utility>
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h" #include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h" #include "dataset/engine/db_connector.h"
@ -26,7 +27,10 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Builder constructor. Creates the builder object. // Builder constructor. Creates the builder object.
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {} SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size();
}
Status SkipOp::Builder::SanityCheck() const { Status SkipOp::Builder::SanityCheck() const {
if (build_max_skips_ < 0) { if (build_max_skips_ < 0) {
@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object. // The builder "build" method creates the final object.
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) { Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck()); RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<SkipOp>(build_max_skips_); *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
return Status::OK(); return Status::OK();
} }
// Constructor of the SkipOp. // Constructor of the SkipOp.
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {} SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
: PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
// Destructor // Destructor
SkipOp::~SkipOp() {} SkipOp::~SkipOp() {}
@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
<< "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_; << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
} }
// Removed-side implementation of SkipOp::GetNextBuffer.
// Fetches a buffer from the first child and drops leading rows until max_skips_ rows have been
// skipped. A buffer may hold several rows, so rows are popped one at a time and a buffer that
// becomes empty is replaced by the next buffer from the child.
// @param p_buffer - output pointer that receives the resulting buffer
// @param worker_id - the worker id; forwarded to the child fetch and to EoeReceived/EofReceived
// @param retry_if_eoe - not read in this body; the child is always called with `true`
// @return Status - the first non-OK status from the child fetch, PopRow or the EOE/EOF handlers,
//                  otherwise Status::OK()
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
// There must be a child to pull buffers from.
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// Drop rows until max_skips_ rows have been skipped in total.
while (skip_count_ < max_skips_) {
// An EOE/EOF buffer ends the skipping early; it is handled after the loop.
if (buf->eoe() || buf->eof()) {
break;
}
// A buffer can hold more than one row: drop the whole buffer while that still leaves the
// running total below max_skips_, otherwise drop only the rows that are still owed.
TensorRow drop_row;
int row_num = buf->NumRows();
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
skip_count_ += drop_num;
// Popped rows land in drop_row and are discarded.
for (int i = 0; i < drop_num; i++) {
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
}
// Every row of this buffer was dropped: continue with the next buffer from the child.
if (buf->NumRows() == 0) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
}
// Handling eoe (SkipOp::EoeReceived resets skip_count_ to 0).
if (buf->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
}
// Handling eof
if (buf->eof()) {
RETURN_IF_NOT_OK(EofReceived(worker_id));
}
// Hand the (possibly partially drained, or EOE/EOF) buffer to the caller.
*p_buffer = std::move(buf);
return Status::OK();
}
// Base-class override for handling cases when an eoe is received. // Base-class override for handling cases when an eoe is received.
Status SkipOp::EoeReceived(int32_t worker_id) { Status SkipOp::EoeReceived(int32_t worker_id) {
skip_count_ = 0; skip_count_ = 0;
@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
return Status::OK(); return Status::OK();
} }
// NOTE(review): side-by-side diff rendering. In the seven lines that follow, the left column holds
// the removed "inlined operator" comment and the removed one-line error stub; the right column
// holds the opening of the new SkipOp::operator()(): its signature, TaskManager::FindMe()->Post(),
// the declaration of curr_buffer, the first GetNextInput call and the header of the outer
// `while (curr_buffer->eof() == false)` loop.
//
// New SkipOp::operator()(): main loop of the skip operator.
// Pulls buffers with GetNextInput, resets skip_count_ at the top of every outer iteration, drops
// rows until skip_count_ reaches max_skips_, forwards the remaining buffers to out_connector_,
// pushes a new EOE buffer whenever the inner loop ends, and pushes a new EOF buffer once the outer
// loop ends.
// @return Status - the first non-OK status from GetNextInput, PopRow or out_connector_->Add,
//                  otherwise Status::OK()
// Class functor operator () override. // main entry point for skip
// Most dataset ops operate by launching a thread (see ExecutionTree). Status SkipOp::operator()() {
// However, the SkipOp is defined as a inlined operator, so it is invalid to TaskManager::FindMe()->Post();
// launch the functor since this op runs inlined inside another operator. The std::unique_ptr<DataBuffer> curr_buffer;
// function is overloaded to ensure that it is not called by mistake (it will RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
// generate an error). while (curr_buffer->eof() == false) {
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); } // Reset count
skip_count_ = 0;
// Inner loop: runs until the current buffer is an EOE buffer.
while (curr_buffer->eoe() == false) {
// Drop rows until max_skips_ rows have been skipped since the last reset.
while (skip_count_ < max_skips_) {
// Stop skipping as soon as an EOE/EOF buffer shows up.
// NOTE(review): after this break the EOE/EOF buffer itself is passed to out_connector_->Add
// below and another GetNextInput is issued before either loop condition is re-checked.
// Confirm this is the intended behaviour when the input has fewer than max_skips_ rows.
if (curr_buffer->eoe() || curr_buffer->eof()) {
break;
}
// A buffer can hold more than one row: drop the whole buffer while that still leaves the
// running total below max_skips_, otherwise drop only the rows that are still owed.
TensorRow drop_row;
int row_num = curr_buffer->NumRows();
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
skip_count_ += drop_num;
// Popped rows land in drop_row and are discarded.
for (int i = 0; i < drop_num; i++) {
RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
}
// Every row of this buffer was dropped: fetch the next input buffer.
if (curr_buffer->NumRows() == 0) {
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
}
}
// Forward the current buffer to the output connector, then fetch the next input buffer.
// (The first argument 0 is presumably the worker id - TODO confirm against the connector API.)
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
}
// The inner loop ended on an EOE buffer: push a new EOE buffer downstream and keep reading until
// an EOF buffer arrives.
// NOTE(review): std::move around std::make_unique<...>(...) is redundant; the call result is
// already an rvalue.
MS_LOG(DEBUG) << "Skip operator EOE Received.";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
}
// The outer loop ended on an EOF buffer: push a new EOF buffer downstream and finish.
MS_LOG(DEBUG) << "Skip operator EOF Received.";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
return Status::OK();
}
// Base-class override for handling cases when an eof is received. // Base-class override for handling cases when an eof is received.
Status SkipOp::EofReceived(int32_t worker_id) { Status SkipOp::EofReceived(int32_t worker_id) {

@ -42,6 +42,7 @@ class SkipOp : public PipelineOp {
private: private:
int32_t build_max_skips_; int32_t build_max_skips_;
int32_t builder_op_connector_size_;
Status SanityCheck() const; Status SanityCheck() const;
}; };
@ -49,7 +50,7 @@ class SkipOp : public PipelineOp {
// Constructor of the SkipOp. // Constructor of the SkipOp.
// @note The builder class should be used to call it // @note The builder class should be used to call it
// @param count - The number of skips to do // @param count - The number of skips to do
explicit SkipOp(int32_t count); explicit SkipOp(int32_t count, int32_t op_connector_size);
// Destructor // Destructor
~SkipOp(); ~SkipOp();
@ -60,23 +61,11 @@ class SkipOp : public PipelineOp {
void Print(std::ostream &out, bool show_all) const override; void Print(std::ostream &out, bool show_all) const override;
// Class functor operator () override. // Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree). // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// However, the SkipOp is defined as a inlined operator, so it is invalid to launch the // provide the master loop that drives the logic for performing the work
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return // @return Status - The error code return
Status operator()() override; Status operator()() override;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
// a buffer from our child.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// Base-class override for handling cases when an eoe is received. // Base-class override for handling cases when an eoe is received.
// @param worker_id - The worker id // @param worker_id - The worker id
Status EoeReceived(int32_t worker_id) override; Status EoeReceived(int32_t worker_id) override;

@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
// SkipOp // SkipOp
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5); std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
rc = my_tree->AssociateNode(skip_op); rc = my_tree->AssociateNode(skip_op);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import numpy as np import numpy as np
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
@ -51,7 +50,7 @@ def generator_md():
def test_generator_skip(): def test_generator_skip():
ds1 = ds.GeneratorDataset(generator_md, ["data"]) ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)
# Here ds1 should be [3, 4] # Here ds1 should be [3, 4]
ds1 = ds1.skip(3) ds1 = ds1.skip(3)
@ -60,6 +59,7 @@ def test_generator_skip():
for data in ds1: for data in ds1:
buf.append(data[0][0]) buf.append(data[0][0])
assert len(buf) == 2 assert len(buf) == 2
assert buf == [3, 4]
def test_skip_1(): def test_skip_1():
@ -72,6 +72,7 @@ def test_skip_1():
for data in ds1: for data in ds1:
buf.append(data[0][0]) buf.append(data[0][0])
assert len(buf) == 0 assert len(buf) == 0
assert buf == []
def test_skip_2(): def test_skip_2():
@ -84,6 +85,7 @@ def test_skip_2():
for data in ds1: for data in ds1:
buf.append(data[0][0]) buf.append(data[0][0])
assert len(buf) == 5 assert len(buf) == 5
assert buf == [0, 1, 2, 3, 4]
def test_skip_repeat_1(): def test_skip_repeat_1():
@ -99,6 +101,7 @@ def test_skip_repeat_1():
for data in ds1: for data in ds1:
buf.append(data[0][0]) buf.append(data[0][0])
assert len(buf) == 7 assert len(buf) == 7
assert buf == [3, 4, 0, 1, 2, 3, 4]
def test_skip_repeat_2(): def test_skip_repeat_2():
@ -114,6 +117,7 @@ def test_skip_repeat_2():
for data in ds1: for data in ds1:
buf.append(data[0][0]) buf.append(data[0][0])
assert len(buf) == 4 assert len(buf) == 4
assert buf == [3, 4, 3, 4]
def test_skip_repeat_3(): def test_skip_repeat_3():
@ -132,6 +136,62 @@ def test_skip_repeat_3():
for data in ds1: for data in ds1:
buf.append(data[0][0]) buf.append(data[0][0])
assert len(buf) == 6 assert len(buf) == 6
assert buf == [3, 4, 3, 4, 3, 4]
def test_skip_take_1():
    """take(4) followed by skip(2) on generator_md must yield exactly [2, 3]."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3]
    ds1 = ds1.take(4)

    # Here ds1 should be [2, 3]
    ds1 = ds1.skip(2)

    # Collect the first element of the single column of every row.
    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 2
    assert buf == [2, 3]
def test_skip_take_2():
    """skip(2) followed by take(2) on generator_md must yield exactly [2, 3]."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [2, 3, 4]
    ds1 = ds1.skip(2)

    # Here ds1 should be [2, 3]
    ds1 = ds1.take(2)

    # Collect the first element of the single column of every row.
    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 2
    assert buf == [2, 3]
def generator_1d():
    """Yield 64 rows; each row is a 1-tuple holding a one-element numpy array with values 0..63 in order."""
    for i in range(64):
        yield (np.array([i]), )
def test_skip_filter_1():
    """skip(5) followed by a parallel filter (data < 11) on generator_1d must yield 5..10 in order."""
    dataset = ds.GeneratorDataset(generator_1d, ['data'])
    # Skip first, then filter with 4 parallel workers.
    dataset = dataset.skip(5)
    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)

    # Collect the first element of the single column of every row.
    buf = []
    for item in dataset:
        buf.append(item[0][0])
    assert buf == [5, 6, 7, 8, 9, 10]
def test_skip_filter_2():
    """A parallel filter (data < 11) followed by skip(5) on generator_1d must yield 5..10 in order."""
    dataset = ds.GeneratorDataset(generator_1d, ['data'])
    # Filter with 4 parallel workers first, then skip.
    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    dataset = dataset.skip(5)

    # Collect the first element of the single column of every row.
    buf = []
    for item in dataset:
        buf.append(item[0][0])
    assert buf == [5, 6, 7, 8, 9, 10]
if __name__ == "__main__": if __name__ == "__main__":
@ -142,3 +202,7 @@ if __name__ == "__main__":
test_skip_repeat_1() test_skip_repeat_1()
test_skip_repeat_2() test_skip_repeat_2()
test_skip_repeat_3() test_skip_repeat_3()
test_skip_take_1()
test_skip_take_2()
test_skip_filter_1()
test_skip_filter_2()

Loading…
Cancel
Save