initial commit, start of BucketBatchByLengthOp

c implementation done, just need to call batch/pad

added python api and validator

added pybind/de_pipeline stuff, fixed some compile errors, figure out how null py::function works

added tiny bit of doc

integrated with static batch methods

fixed some bugs

some more bug fixes and cleanup

ci fix

fix ci

ci fix

fix ci

added test_cases and debugged

addressed code review comments

addressed code review comments

ci fix

ci fix

addressed code review comments

addressed code review comments
pull/1984/head
Peilin Wang 5 years ago
parent ff0590315c
commit 848e07d022

@@ -19,62 +19,65 @@
#include <map>
#include "common/utils.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp},
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
{kConcat, &DEPipeline::ParseConcatOp},
{kRename, &DEPipeline::ParseRenameOp},
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
{kGenerator, &DEPipeline::ParseGeneratorOp},
{kTfReader, &DEPipeline::ParseTFReaderOp},
{kProject, &DEPipeline::ParseProjectOp},
{kTake, &DEPipeline::ParseTakeOp},
{kImageFolder, &DEPipeline::ParseImageFolderOp},
{kMnist, &DEPipeline::ParseMnistOp},
{kManifest, &DEPipeline::ParseManifestOp},
{kVoc, &DEPipeline::ParseVOCOp},
{kCoco, &DEPipeline::ParseCocoOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kStorage, &DEPipeline::ParseStorageOp},
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
{kConcat, &DEPipeline::ParseConcatOp},
{kRename, &DEPipeline::ParseRenameOp},
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
{kGenerator, &DEPipeline::ParseGeneratorOp},
{kTfReader, &DEPipeline::ParseTFReaderOp},
{kProject, &DEPipeline::ParseProjectOp},
{kTake, &DEPipeline::ParseTakeOp},
{kImageFolder, &DEPipeline::ParseImageFolderOp},
{kMnist, &DEPipeline::ParseMnistOp},
{kManifest, &DEPipeline::ParseManifestOp},
{kVoc, &DEPipeline::ParseVOCOp},
{kCoco, &DEPipeline::ParseCocoOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
@@ -672,6 +675,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::vector<std::string> mandatory_arguments = {"length_dependent_columns", "bucket_boundaries",
"bucket_batch_sizes"};
for (auto name : mandatory_arguments) {
if (args[name.c_str()].is_none()) {
std::string err_msg = "Error: " + name + " is not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
std::shared_ptr<BucketBatchByLengthOp::Builder> builder = std::make_shared<BucketBatchByLengthOp::Builder>(
ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]),
ToIntVector(args[mandatory_arguments[2].c_str()]));
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "length_dependent_columns") {
(void)builder->SetLengthDependentColumns(ToStringVector(value));
}
if (key == "bucket_boundaries") {
(void)builder->SetBucketBoundaries(ToIntVector(value));
}
if (key == "bucket_batch_sizes") {
(void)builder->SetBucketBatchSizes(ToIntVector(value));
}
if (key == "element_length_function") {
(void)builder->SetElementLengthFunction(value.cast<py::function>());
}
if (key == "pad_info") {
PadInfo pad_info;
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
(void)builder->SetPadInfo(pad_info);
}
if (key == "pad_to_bucket_boundary") {
(void)builder->SetPadToBucketBoundary(ToBool(value));
}
if (key == "drop_remainder") {
(void)builder->SetDropRemainder(ToBool(value));
}
}
}
std::shared_ptr<BucketBatchByLengthOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
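For orientation: the args dict consumed above is assembled by BucketBatchByLengthDataset.get_args() on the Python side. A representative payload (the values here are illustrative only) looks like:

args = {
    "length_dependent_columns": ["col1"],
    "bucket_boundaries": [5, 10],
    "bucket_batch_sizes": [5, 1, 1],
    "element_length_function": None,  # a Python callable when provided
    "pad_info": {"col1": ([2, None], -1)},
    "pad_to_bucket_boundary": False,
    "drop_remainder": False,
}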
Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
// Right now barrier should only take num_rows_per_buffer = 1

@@ -40,6 +40,7 @@ enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBucketBatch,
kBarrier,
kCache,
kRepeat,
@@ -121,6 +122,8 @@ class DEPipeline {
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

@@ -616,6 +616,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("STORAGE", OpName::kStorage)
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BUCKETBATCH", OpName::kBucketBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)

@@ -8,6 +8,7 @@ add_library(engine-datasetops OBJECT
pipeline_op.cc
barrier_op.cc
batch_op.cc
bucket_batch_by_length_op.cc
device_queue_op.cc
map_op.cc
project_op.cc

@@ -193,6 +193,22 @@ class BatchOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "BatchOp"; }
// Batch the rows in the src table, then put the resulting batch into the dest table.
// These methods are static and public so that BucketBatchByLengthOp can reuse them.
// @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
// @param const std::unique_ptr<TensorQTable> *dest - dest table to hold the batched rows
// @param dsize_t batch_size - number of rows per batch
// @return Status - The error code returned
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
dsize_t batch_size);
// Pad each column in the table according to pad_info.
// @param std::unique_ptr<TensorQTable> *table - table whose rows will be padded in place
// @param const PadInfo &pad_info - mapping from column name to (pad shape, pad value)
// @param const std::unordered_map<std::string, int32_t> &column_name_id_map - column names to index mapping
// @return Status - The error code returned
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);
private:
// Worker thread for doing the memcpy of batch
// @param int32_t param workerId
@@ -203,16 +219,6 @@ class BatchOp : public ParallelOp {
// @return Status - The error code return
Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
std::unique_ptr<DataBuffer> *db);
// batch the rows in src table then put it to dest table
// @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
// @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
// @param int32_t size - batch_size
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
dsize_t batch_size);
// Function that calls pyfunc to perform map on batch
// @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
// @return Status - The error code return
@@ -229,13 +235,6 @@ class BatchOp : public ParallelOp {
std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
std::vector<std::vector<dsize_t>> *pad_shapes);
// @param table
// @param const PadInfo &pad_info pad info
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);
// the number of thread pulling from the mOutConnector of the Op below
// @return int32_t, 1
int32_t num_consumers() const override { return 1; }

@@ -0,0 +1,242 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "dataset/core/pybind_support.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/status.h"
namespace py = pybind11;
namespace mindspore {
namespace dataset {
BucketBatchByLengthOp::Builder::Builder(std::vector<std::string> length_dependent_columns,
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes)
: builder_length_dependent_columns_(length_dependent_columns),
builder_bucket_boundaries_(bucket_boundaries),
builder_bucket_batch_sizes_(bucket_batch_sizes),
builder_pad_info_({}),
builder_pad_to_bucket_boundary_(false),
builder_drop_remainder_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_op_connector_size_ = config_manager->op_connector_size();
}
Status BucketBatchByLengthOp::Builder::SanityCheck() {
std::string error_message;
if (builder_length_dependent_columns_.empty()) {
error_message += "At least 1 column must be specified for element length calculation.\n";
}
if (builder_bucket_boundaries_.empty()) {
error_message += "At least 1 bucket boundary must be specified.\n";
}
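// Note: n boundaries define n + 1 buckets, [0, b[0]), [b[0], b[1]), ..., [b[n-1], inf),
// so exactly n + 1 batch sizes are required.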
if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) {
error_message += "There must be exactly one more bucket batch size than there are bucket boundaries.\n";
}
CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message);
return Status::OK();
}
Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op) {
RETURN_IF_NOT_OK(SanityCheck());
// insert 0 for the first bucket
builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0);
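// e.g. user-supplied boundaries {5, 10} become {0, 5, 10}; bucket i then covers
// lengths in [boundaries[i], boundaries[i + 1]), with the last bucket unbounded.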
*new_bucket_batch_by_length_op = std::make_shared<BucketBatchByLengthOp>(
builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_,
builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_,
builder_op_connector_size_);
return Status::OK();
}
BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns,
std::vector<int32_t> bucket_boundaries,
std::vector<int32_t> bucket_batch_sizes,
py::function element_length_function, PadInfo pad_info,
bool pad_to_bucket_boundary, bool drop_remainder,
int32_t op_connector_size)
: PipelineOp(op_connector_size),
length_dependent_columns_(length_dependent_columns),
bucket_boundaries_(bucket_boundaries),
bucket_batch_sizes_(bucket_batch_sizes),
element_length_function_(element_length_function),
pad_info_(pad_info),
pad_to_bucket_boundary_(pad_to_bucket_boundary),
drop_remainder_(drop_remainder),
batch_count_(0) {
for (size_t i = 0; i < bucket_batch_sizes_.size(); i++) {
buckets_.push_back(std::make_unique<TensorQTable>());
}
}
Status BucketBatchByLengthOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; }
Status BucketBatchByLengthOp::operator()() {
TaskManager::FindMe()->Post();
TensorRow current_row;
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row));
RETURN_IF_NOT_OK(AssignColMapFromChild());
while (!child_iterator_->eof_handled()) {
while (!current_row.empty()) {
int32_t element_length;
RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row));
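// bucket_boundaries_ is ascending and starts with the inserted 0, so scanning
// down from the top finds the first boundary <= element_length; e.g. with
// boundaries {0, 5, 10}, a length of 7 lands in bucket 1, i.e. [5, 10).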
int bucket_index = bucket_boundaries_.size() - 1;
while (element_length < bucket_boundaries_[bucket_index]) {
bucket_index--;
}
buckets_[bucket_index]->push_back(current_row);
if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) {
RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index]));
}
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row));
}
// got EOE, do what we need to do with remainders in each bucket
if (!drop_remainder_) {
for (size_t i = 0; i < bucket_boundaries_.size(); i++) {
if (!buckets_[i]->empty()) {
RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size()));
}
}
}
// need to send EOE manually since we set state to idle in EoeReceived()
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row));
}
return Status::OK();
}
Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) {
// call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of
// the single column specified in length_dependent_columns_
if (element_length_function_) {
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
try {
size_t number_of_arguments = length_dependent_columns_.size();
py::tuple input_arguments(number_of_arguments);
for (size_t i = 0; i < number_of_arguments; i++) {
py::array argument_value;
int32_t column_index = column_name_id_map_[length_dependent_columns_[i]];
RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value));
input_arguments[i] = argument_value;
}
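// the * unpacks the tuple so the pyfunc receives one numpy array per
// length-dependent column, matching the Python-side signature.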
py::object length = element_length_function_(*input_arguments);
*out_element_length = length.cast<int32_t>();
if (*out_element_length < 0) {
return Status(StatusCode::kPyFuncException, "Element length function should return a non negative integer.");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
} catch (const py::cast_error &e) {
return Status(StatusCode::kPyFuncException, "Count not cast output of element length function to int32_t.");
}
} else {
*out_element_length = element[column_name_id_map_[length_dependent_columns_[0]]]->shape()[0];
}
return Status::OK();
}
Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) {
std::unique_ptr<TensorQTable> *bucket = &buckets_[bucket_index];
PadInfo pad_info_copy = pad_info_;
if (pad_to_bucket_boundary_) {
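// every element in bucket i has length < bucket_boundaries_[i + 1], so the
// widest an unknown dimension can need to be is bucket_boundaries_[i + 1] - 1.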
for (auto &pair : pad_info_copy) {
std::vector<dsize_t> pad_shape = pair.second.first.AsVector();
for (size_t i = 0; i < pad_shape.size(); i++) {
if (pad_shape[i] == TensorShape::kDimUnknown) {
if (bucket_index + 1 >= bucket_boundaries_.size()) {
std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message);
}
pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1;
}
}
pair.second.first = TensorShape(pad_shape);
}
}
// PadColumns will change the data in bucket
RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_));
std::unique_ptr<TensorQTable> batched_bucket = std::make_unique<TensorQTable>();
RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size));
(*bucket)->clear();
std::unique_ptr<DataBuffer> batched_buffer = std::make_unique<DataBuffer>(batch_count_, DataBuffer::kDeBFlagNone);
batched_buffer->set_tensor_table(std::move(batched_bucket));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer)));
batch_count_++;
return Status::OK();
}
Status BucketBatchByLengthOp::Reset() {
batch_count_ = 0;
for (size_t i = 0; i < buckets_.size(); i++) {
buckets_[i] = std::make_unique<TensorQTable>();
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@@ -0,0 +1,153 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_
#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/pipeline_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
class DataBuffer;
class BucketBatchByLengthOp : public PipelineOp {
public:
class Builder {
public:
Builder(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries,
std::vector<int32_t> bucket_batch_sizes);
~Builder() = default;
Builder &SetLengthDependentColumns(std::vector<std::string> length_dependent_columns) {
builder_length_dependent_columns_ = length_dependent_columns;
return *this;
}
Builder &SetBucketBoundaries(std::vector<int32_t> bucket_boundaries) {
builder_bucket_boundaries_ = bucket_boundaries;
return *this;
}
Builder &SetBucketBatchSizes(std::vector<int32_t> bucket_batch_sizes) {
builder_bucket_batch_sizes_ = bucket_batch_sizes;
return *this;
}
Builder &SetElementLengthFunction(py::function element_length_function) {
builder_element_length_function_ = element_length_function;
return *this;
}
Builder &SetPadInfo(PadInfo pad_info) {
builder_pad_info_ = pad_info;
return *this;
}
Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) {
builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary;
return *this;
}
Builder &SetDropRemainder(bool drop_remainder) {
builder_drop_remainder_ = drop_remainder;
return *this;
}
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
Status Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op);
private:
Status SanityCheck();
std::vector<std::string> builder_length_dependent_columns_;
std::vector<int32_t> builder_bucket_boundaries_;
std::vector<int32_t> builder_bucket_batch_sizes_;
py::function builder_element_length_function_;
PadInfo builder_pad_info_;
bool builder_pad_to_bucket_boundary_;
bool builder_drop_remainder_;
int32_t builder_op_connector_size_;
};
BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries,
std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info,
bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size);
// Might need to batch remaining buckets after receiving eoe, so override this method.
// @param int32_t workerId
// @return Status - The error code returned
Status EoeReceived(int32_t) override;
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded
// @param bo - reference to the BucketBatchByLengthOp to display
// @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) {
bo.Print(out, false);
return out;
}
// Main loop of BucketBatchByLengthOp
// @return Status - The error code returned
Status operator()() override;
// Function that is called by ResetOp at the end of every epoch
// @return Status - The error code returned
Status Reset() override;
private:
Status ObtainElementLength(int32_t *out_element_length, TensorRow element);
Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size);
std::vector<std::string> length_dependent_columns_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;
py::function element_length_function_;
PadInfo pad_info_;
bool pad_to_bucket_boundary_;
bool drop_remainder_;
int32_t batch_count_;
std::unique_ptr<ChildIterator> child_iterator_;
std::vector<std::unique_ptr<TensorQTable>> buckets_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_

@@ -42,9 +42,9 @@ from .iterators import DictIterator, TupleIterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split, check_cluedataset
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split, check_bucket_batch_by_length, check_cluedataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
@@ -165,6 +165,76 @@ class Dataset:
args["num_parallel_workers"] = self.num_parallel_workers
return args
@check_bucket_batch_by_length
def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes,
element_length_function=None, pad_info=None,
pad_to_bucket_boundary=False, drop_remainder=False):
"""
Bucket elements according to their lengths, and pad and batch the buckets when
they are full.
A length function is called on each row in the dataset. The row is then
bucketed based on its length and bucket_boundaries. When a bucket reaches its
corresponding size specified in bucket_batch_sizes, the entire bucket will be
padded according to pad_info, and then batched. Each batch will be full,
except possibly the last batch for each bucket.
Args:
column_names (list of string): Columns passed to element_length_function.
bucket_boundaries (list of int): A list consisting of the upper boundaries
of the buckets. Must be strictly increasing. If there are n boundaries,
n+1 buckets are created: one bucket for [0, bucket_boundaries[0]), one
bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
0 <= i < n-1, and one bucket for [bucket_boundaries[n-1], inf).
bucket_batch_sizes (list of int): A list consisting of the batch sizes for
each bucket. Must contain len(bucket_boundaries)+1 elements.
element_length_function (Callable, optional): A function that takes in
len(column_names) arguments and returns an int. If no value is
provided, then len(column_names) must be 1, and the size of the first
dimension of that column will be taken as the length (default=None).
pad_info (dict, optional): Represents how to pad each column. The key
corresponds to the column name, and the value must be a tuple of 2 elements.
The first element corresponds to the shape to pad to, and the second
element corresponds to the value to pad with. If a column is not
specified, then that column will be padded to the longest in the current
batch, and 0 will be used as the padding value. Any None dimensions will
be padded to the longest in the current batch, unless
pad_to_bucket_boundary is True. If no padding is wanted, set pad_info
to None (default=None).
pad_to_bucket_boundary (bool, optional): If True, will pad each None
dimension in pad_info to the corresponding bucket boundary minus 1. If
there are any elements that fall into the last bucket, an error will
occur (default=False).
drop_remainder (bool, optional): If True, will drop the last batch for each
bucket if it is not a full batch (default=False).
Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object.
>>>
>>> # bucket rows by length: rows with length under 5 are batched five at
>>> # a time, and all longer rows are emitted one per batch.
>>> column_names = ["col1", "col2"]
>>> bucket_boundaries = [5, 10]
>>> bucket_batch_sizes = [5, 1, 1]
>>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
>>>
>>> # will pad col1 to shape [2, bucket_boundaries[i] - 1] where i is the
>>> # index of the bucket that is currently being batched.
>>> # will pad col2 to a shape where each dimension is the longest in all
>>> # the elements currently being batched.
>>> pad_info = {"col1": ([2, None], -1)}
>>> pad_to_bucket_boundary = True
>>>
>>> data = data.bucket_batch_by_length(column_names, bucket_boundaries,
>>>                                    bucket_batch_sizes,
>>>                                    element_length_function, pad_info,
>>>                                    pad_to_bucket_boundary)
"""
return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
element_length_function, pad_info,
pad_to_bucket_boundary, drop_remainder)
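As a rough end-to-end sketch of the new API (not part of this diff; the generator source and the column name "col1" below are made up for illustration):

import numpy as np
import mindspore.dataset as ds

# rows of a single column "col1" with lengths 1 through 9
def gen():
    for i in range(1, 10):
        yield (np.arange(i, dtype=np.int32),)

data = ds.GeneratorDataset(gen, column_names=["col1"])
# with one column and no element_length_function, the first dimension of
# "col1" is taken as the length: rows shorter than 5 are batched five at a
# time, and all longer rows are emitted one per batch.
data = data.bucket_batch_by_length(["col1"], bucket_boundaries=[5],
                                   bucket_batch_sizes=[5, 1])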
@check_batch
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
input_columns=None, pad_info=None):
@@ -1400,6 +1470,47 @@ class DatasetOp(Dataset):
# No need for __init__ since it is the same as the super's init
class BucketBatchByLengthDataset(DatasetOp):
"""
The result of applying the bucket_batch_by_length operator to the input dataset.
"""
def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes,
element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder):
super().__init__()
self.column_names = column_names
self.bucket_boundaries = bucket_boundaries
self.bucket_batch_sizes = bucket_batch_sizes
self.element_length_function = element_length_function
self.pad_info = pad_info
self.pad_to_bucket_boundary = pad_to_bucket_boundary
self.drop_remainder = drop_remainder
self.input.append(input_dataset)
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs
def get_args(self):
args = super().get_args()
args["length_dependent_columns"] = self.column_names
args["bucket_boundaries"] = self.bucket_boundaries
args["bucket_batch_sizes"] = self.bucket_batch_sizes
args["element_length_function"] = self.element_length_function
args["pad_info"] = self.pad_info
args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary
args["drop_remainder"] = self.drop_remainder
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
Return:
None, the number of batches cannot be determined until runtime.
"""
return None
class BatchDataset(DatasetOp):
"""

@@ -132,6 +132,8 @@ class Iterator:
op_type = OpName.MINDRECORD
elif isinstance(dataset, de.BatchDataset):
op_type = OpName.BATCH
elif isinstance(dataset, de.BucketBatchByLengthDataset):
op_type = OpName.BUCKETBATCH
elif isinstance(dataset, de.SyncWaitDataset):
op_type = OpName.BARRIER
elif isinstance(dataset, de.ZipDataset):

@@ -752,6 +752,67 @@ def check_pad_info(key, val):
check_type(val[1], "pad_value", (int, float, str, bytes))
def check_bucket_batch_by_length(method):
"""check the input arguments of bucket_batch_by_length."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
check_param_type(nreq_param_list, param_dict, list)
# check column_names: must be list of string.
column_names = param_dict.get("column_names")
all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")
element_length_function = param_dict.get("element_length_function")
if element_length_function is None and len(column_names) != 1:
raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
# check bucket_boundaries: must be list of int, positive and strictly increasing
bucket_boundaries = param_dict.get('bucket_boundaries')
if not bucket_boundaries:
raise ValueError("bucket_boundaries cannot be empty.")
all_int = all(isinstance(item, int) for item in bucket_boundaries)
if not all_int:
raise TypeError("bucket_boundaries should be a list of int.")
all_positive = all(item > 0 for item in bucket_boundaries)
if not all_positive:
raise ValueError("bucket_boundaries must only contain positive numbers.")
for i in range(len(bucket_boundaries) - 1):
if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
raise ValueError("bucket_boundaries should be strictly increasing.")
# check bucket_batch_sizes: must be list of int and positive
bucket_batch_sizes = param_dict.get('bucket_batch_sizes')
if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")
all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
if not all_int:
raise TypeError("bucket_batch_sizes should be a list of int.")
all_positive = all(item > 0 for item in bucket_batch_sizes)
if not all_positive:
raise ValueError("bucket_batch_sizes must only contain positive numbers.")
if param_dict.get('pad_info') is not None:
check_type(param_dict["pad_info"], "pad_info", dict)
for k, v in param_dict.get('pad_info').items():
check_pad_info(k, v)
return method(*args, **kwargs)
return new_method
def check_batch(method):
"""check the input arguments of batch."""

File diff suppressed because it is too large