Realize take op and add ut

pull/330/head
ms_yan 5 years ago
parent 71b63c3fcf
commit f0c07c3fa6

@@ -54,6 +54,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
     {kGenerator, &DEPipeline::ParseGeneratorOp},
     {kTfReader, &DEPipeline::ParseTFReaderOp},
     {kProject, &DEPipeline::ParseProjectOp},
+    {kTake, &DEPipeline::ParseTakeOp},
     {kImageFolder, &DEPipeline::ParseImageFolderOp},
     {kMnist, &DEPipeline::ParseMnistOp},
     {kManifest, &DEPipeline::ParseManifestOp},
@@ -650,7 +651,16 @@ Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp
   return Status::OK();
 }
 
-DsOpPtr DEPipeline::ParseTakeOp(const py::dict &args) const { return DsOpPtr(); }
+Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+  if (args["count"].is_none()) {
+    std::string err_msg = "Error: count is invalid or not set.";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  std::shared_ptr<TakeOp> op;
+  RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
 
 Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
   std::shared_ptr<ZipOp::Builder> builder = std::make_shared<ZipOp::Builder>();
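On the Python side, the args dict handed to ParseTakeOp is built by TakeDataset.get_args() later in this commit. A minimal sketch of the same guard in plain Python (illustrative only; the real call crosses pybind11):

    # Hypothetical stand-in for the py::dict received by ParseTakeOp.
    args = {"count": 50}                  # produced by TakeDataset.get_args()
    if args.get("count") is None:         # mirrors args["count"].is_none()
        raise ValueError("Error: count is invalid or not set.")
    count = int(args["count"])            # mirrors ToInt(args["count"])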

@@ -116,7 +116,7 @@ class DEPipeline {
   Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
-  DsOpPtr ParseTakeOp(const py::dict &args) const;
+  Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
   Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

@@ -38,6 +38,7 @@
 #include "dataset/engine/datasetops/source/mindrecord_op.h"
 #include "dataset/engine/datasetops/source/storage_op.h"
 #include "dataset/engine/datasetops/source/tf_reader_op.h"
+#include "dataset/engine/datasetops/take_op.h"
 #include "dataset/engine/datasetops/zip_op.h"
 #include "dataset/engine/execution_tree.h"
 #include "dataset/util/status.h"

@@ -5,13 +5,13 @@ add_library(engine-datasetops OBJECT
     parallel_op.cc
     pipeline_op.cc
     batch_op.cc
-    batch_op.cc
     device_queue_op.cc
     map_op.cc
     project_op.cc
     rename_op.cc
     repeat_op.cc
     skip_op.cc
+    take_op.cc
     shuffle_op.cc
     zip_op.cc
 )

@@ -88,6 +88,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
   // If buffer is none or the rows of buffer is 0,
   // then get a buffer from child.
   if (!buf || buf->NumRows() == 0) {
+    if (buf && buf->eof()) {
+      *p_buffer = std::move(buf);
+      return Status::OK();
+    }
     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
   }

@@ -0,0 +1,146 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "common/utils.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
Status TakeOp::Builder::SanityCheck() const {
if (build_max_takes_ <= 0) {
std::string err_msg("Take count must be greater than 0.");
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
// The builder "build" method creates the final object.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<TakeOp>(build_max_takes_);
return Status::OK();
}
// Constructor of the TakeOp.
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
// A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const {
// Call base class printer first
PipelineOp::Print(out, show_all);
// Then display our own stuff
out << "TakeOp:"
<< "\nCurrent take count: " << take_count_ << "\nMax take count: " << max_takes_;
}
// This function will be called multiple times to return buffers. It stops once the required
// max take count is reached or an EOF buffer is received.
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
if (child_.empty()) {
RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
}
std::unique_ptr<DataBuffer> buf;
bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
if (take_count_ == max_takes_) {
if (state_ == OpState::kDeOpRunning) {
MS_LOG(INFO) << "meet max count and push-back eoe buffer.";
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
*p_buffer = std::move(eoe_buffer);
state_ = OpState::kDeOpIdle;
// Reset the count and drain the rest of the current epoch from the child
if (!last_repeat) {
take_count_ = 0;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
}
}
} else {
MS_LOG(INFO) << "meet max count and push-back eof buffer.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
*p_buffer = std::move(eof_buffer);
take_count_ = 0;
}
return Status::OK();
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
// An EOE (end of epoch) buffer resets the take count and is forwarded to the caller
if (buf->eoe()) {
take_count_ = 0;
*p_buffer = std::move(buf);
return Status::OK();
}
// An EOF buffer marks the end of the dataset and is forwarded as-is
if (buf->eof()) {
*p_buffer = std::move(buf);
return Status::OK();
}
// Below the max take count: fill the output buffer from the fetched one
if (take_count_ < max_takes_) {
RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
}
return Status::OK();
}
// FillBuffer prepares the buffer to be returned to the caller, trimming rows that would exceed the cap
Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
int32_t buffer_size = (*buffer)->NumRows();
if (take_count_ + buffer_size < max_takes_) {
*data_buffer = std::move(*buffer);
take_count_ = take_count_ + buffer_size;
} else {
MS_LOG(INFO) << "In last buffer: Push one buffer.";
std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
while (take_count_ < max_takes_) {
TensorRow new_row;
RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row));
take_count_++;
new_tensor_table->push_back(new_row);
}
(*buffer)->set_tensor_table(std::move(new_tensor_table));
*data_buffer = std::move(*buffer);
}
return Status::OK();
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overridden to
// ensure that it is not called by mistake (it will generate an error).
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
Status TakeOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToRepeatStack(shared_from_this());
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
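For intuition, the take logic above can be modeled in a few lines of Python (a sketch with hypothetical names, not MindSpore code; buffers are plain lists of rows):

    def take_buffers(child_buffers, max_takes):
        take_count = 0
        for buf in child_buffers:
            if take_count == max_takes:            # max count met: stop (EOE in the real op)
                return
            if take_count + len(buf) < max_takes:  # whole buffer fits under the cap
                take_count += len(buf)
                yield buf
            else:                                  # trim the final buffer, as FillBuffer does
                head = buf[:max_takes - take_count]
                take_count = max_takes
                yield head

    assert list(take_buffers([[1, 2, 3], [4, 5, 6], [7, 8]], 5)) == [[1, 2, 3], [4, 5]]

The EOE/EOF forwarding and the per-repeat reset have no analogue in this sketch; they exist so the inlined operator cooperates with repeat epochs in the execution tree.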

@@ -0,0 +1,107 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
namespace mindspore {
namespace dataset {
class TakeOp : public PipelineOp {
public:
// The nested builder class inside of the TakeOp is used to help manage all of the arguments
// for constructing it. This take op is very simple though, so this builder is really just
// provided for a consistent look and feel for creators of Dataset operators overall.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @param count - The number of takes to do
// @return This is a constructor.
explicit Builder(int32_t count);
// Default destructor
~Builder() = default;
// The builder "build" method creates the final object.
// @return shared_ptr to the new TakeOp object
Status Build(std::shared_ptr<TakeOp> *);
private:
int32_t build_max_takes_;
Status SanityCheck() const;
};
// Constructor of the TakeOp.
// @note The builder class should be used to call it
// @param count - The number of takes to do
explicit TakeOp(int32_t count);
// Destructor
~TakeOp() = default;
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded
// @param ro - reference to the TakeOp to display
// @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) {
ro.Print(out, false);
return out;
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overridden to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return
Status operator()() override;
// Gets a buffer from the child node. The caller is typically our parent node.
// @note This function sets the `retry_if_eoe` flag when popping from the child connector. This way,
// this function will retry the pop on the connector and return the next non-EOE buffer, if any.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call its superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
private:
int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes
Status FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer);
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_

@@ -36,7 +36,7 @@ from mindspore import log as logger
 from . import samplers
 from .iterators import DictIterator, TupleIterator
 from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
-    check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
+    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
     check_zip_dataset, check_add_column
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -442,6 +442,33 @@ class Dataset:
         """
         return SkipDataset(self, count)
 
+    @check_take
+    def take(self, count=-1):
+        """
+        Takes at most the given number of elements from the dataset.
+
+        Note:
+            1. If count is greater than the number of elements in the dataset or equal to -1,
+               all the elements in the dataset will be taken.
+            2. The order of take and batch matters. If take comes before the batch operation,
+               the given number of rows is taken; otherwise, the given number of batches is taken.
+
+        Args:
+            count (int, optional): Number of elements to be taken from the dataset (default=-1).
+
+        Returns:
+            TakeDataset, dataset taken.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> # data is an instance of Dataset object.
+            >>> # create a dataset that contains at most 50 elements.
+            >>> data = data.take(50)
+        """
+        if count == -1:
+            return self
+        return TakeDataset(self, count)
+
     @check_zip_dataset
     def zip(self, datasets):
         """
@@ -1100,6 +1127,7 @@ class RepeatDataset(DatasetOp):
         """
         return self.count
 
+
 class SkipDataset(DatasetOp):
     """
     The result of applying Skip operator to the input Dataset.
@@ -1134,6 +1162,41 @@ class SkipDataset(DatasetOp):
         output_size = child_size - self.count
         return output_size
 
+
+class TakeDataset(DatasetOp):
+    """
+    The result of applying Take operator to the input Dataset.
+
+    Args:
+        input_dataset (Dataset): Input Dataset from which elements will be taken.
+        count (int): Number of elements to be taken from the dataset.
+    """
+
+    def __init__(self, input_dataset, count):
+        super().__init__()
+        self.count = count
+        self.input.append(input_dataset)
+        input_dataset.output.append(self)
+        self._input_indexs = input_dataset.input_indexs
+
+    def get_args(self):
+        args = super().get_args()
+        args["count"] = self.count
+        return args
+
+    def get_dataset_size(self):
+        """
+        Get the number of batches in an epoch.
+
+        Return:
+            Number, number of batches.
+        """
+        child_size = self.input[0].get_dataset_size()
+        if child_size < self.count:
+            return child_size
+        return self.count
+
+
 class ZipDataset(DatasetOp):
     """
     The result of applying Zip operator to the input Dataset.
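The size computation in TakeDataset.get_dataset_size() is a clamped minimum; a small sketch of the expected values:

    def take_size(child_size, count):
        # Mirrors TakeDataset.get_dataset_size(): result is min(child_size, count).
        return child_size if child_size < count else count

    assert take_size(100, 50) == 50     # count smaller: count wins
    assert take_size(100, 500) == 100   # child smaller: clamped to the child size

The count == -1 case never reaches this class, since Dataset.take(-1) returns the input dataset unchanged.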

@@ -129,6 +129,8 @@ class Iterator:
             op_type = OpName.REPEAT
         elif isinstance(dataset, de.SkipDataset):
             op_type = OpName.SKIP
+        elif isinstance(dataset, de.TakeDataset):
+            op_type = OpName.TAKE
         elif isinstance(dataset, de.StorageDataset):
             op_type = OpName.STORAGE
         elif isinstance(dataset, de.ImageFolderDatasetV2):

@@ -304,6 +304,9 @@ def create_node(node):
     elif dataset_op == 'SkipDataset':
         pyobj = de.Dataset().skip(node.get('count'))
 
+    elif dataset_op == 'TakeDataset':
+        pyobj = de.Dataset().take(node.get('count'))
+
     elif dataset_op == 'MapDataset':
         tensor_ops = construct_tensor_ops(node.get('operations'))
         pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),

@@ -602,7 +602,7 @@ def check_batch_size(batch_size):
 def check_count(count):
     check_type(count, 'count', int)
     if (count <= 0 and count != -1) or count > INT32_MAX:
-        raise ValueError("repeat count should be either -1 or positive integer.")
+        raise ValueError("count should be either -1 or a positive integer.")
 
 def check_columns(columns, name):
@@ -709,6 +709,7 @@ def check_repeat(method):
     return new_method
 
+
 def check_skip(method):
     """check the input arguments of skip."""
     @wraps(method)
@@ -724,6 +725,21 @@ def check_skip(method):
     return new_method
 
+
+def check_take(method):
+    """check the input arguments of take."""
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+        count = param_dict.get('count')
+        check_count(count)
+        return method(*args, **kwargs)
+
+    return new_method
+
+
 def check_zip(method):
     """check the input arguments of zip."""
     @wraps(method)
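Since check_take delegates to check_count, take accepts the same domain as repeat did before this change; a sketch of the expected outcomes (`data` is any Dataset):

    data.take(5)     # accepted: positive integer
    data.take(-1)    # accepted: -1 means take everything (take() returns `data` itself)
    data.take(0)     # rejected by check_count with ValueError
    data.take("5")   # rejected by check_type inside check_count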
@@ -759,6 +775,7 @@ def check_zip_dataset(method):
     return new_method
 
+
 def check_rename(method):
     """check the input arguments of rename."""
     @wraps(method)

@@ -64,6 +64,7 @@ SET(DE_UT_SRCS
     voc_op_test.cc
     cifar_op_test.cc
     celeba_op_test.cc
+    take_op_test.cc
 )
 
 add_executable(de_ut_tests ${DE_UT_SRCS})

@@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "common/common.h"
#include "common/utils.h"
#include "dataset/core/client.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
namespace common = mindspore::common;
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestTakeOp : public UT::DatasetOpTesting {};
TEST_F(MindDataTestTakeOp, TestTakeProject) {
// Start with an empty execution tree
auto my_tree = std::make_shared<ExecutionTree>();
std::string dataset_path;
dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
// TFReaderOp
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path})
.SetRowsPerBuffer(16)
.SetWorkerConnectorSize(16)
.SetNumWorkers(16);
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
builder.SetDataSchema(std::move(schema));
Status rc = builder.Build(&my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
// TakeOp
std::shared_ptr<TakeOp> my_take_op;
TakeOp::Builder builder_take(5);
rc = builder_take.Build(&my_take_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_take_op);
ASSERT_TRUE(rc.IsOk());
// Set children/root layout.
rc = my_take_op->AddChild(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_take_op);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration.";
rc = my_tree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = my_tree->Launch();
ASSERT_TRUE(rc.IsOk());
// Start the loop of reading tensors from our pipeline
DatasetIterator di(my_tree);
TensorRow tensor_list;
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
int row_count = 0;
while (!tensor_list.empty()) {
MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
// Display the tensor by calling the printer on it
for (int i = 0; i < tensor_list.size(); i++) {
std::ostringstream ss;
ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
}
rc = di.FetchNextTensorRow(&tensor_list);
ASSERT_TRUE(rc.IsOk());
row_count++;
}
ASSERT_EQ(row_count, 5);
}
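A rough Python-level counterpart of this UT (a sketch; the file paths are assumed to point into the same test data tree):

    import mindspore.dataset as ds

    DATA_FILE = "testTFTestAllTypes/test.data"
    SCHEMA_FILE = "testTFTestAllTypes/datasetSchema.json"

    data = ds.TFRecordDataset(DATA_FILE, SCHEMA_FILE)
    data = data.take(5)
    assert sum(1 for _ in data.create_dict_iterator()) == 5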

File diff suppressed because it is too large