diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index be133ea7a9..4a5dac198f 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -207,6 +207,8 @@ int DEPipeline::GetBatchSize() const { return batch_size_; } int DEPipeline::GetRepeatCount() const { return repeat_num_; } +float ToFloat(const py::handle &handle) { return py::reinterpret_borrow(handle); } + int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } @@ -621,6 +623,21 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr if (key == "input_columns") { (void)builder->SetColumnsToMap(ToStringVector(value)); } + if (key == "pad_info") { + std::map> pad_info; + for (auto p : py::reinterpret_borrow(value)) { + if (!p.second.is_none()) { + py::tuple tp = py::reinterpret_borrow(p.second); + CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); + TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); + float pad_val = tp[1].is_none() ? 0 : ToFloat(tp[1]); + (void)pad_info.insert({ToString(p.first), {shape, pad_val}}); + } else { // tuple is None + (void)pad_info.insert({ToString(p.first), {TensorShape({}), 0}}); + } + } + (void)builder->SetPaddingMap(pad_info, true); + } } } diff --git a/mindspore/ccsrc/dataset/core/tensor.h b/mindspore/ccsrc/dataset/core/tensor.h index 74da40c293..4a41d4bd20 100644 --- a/mindspore/ccsrc/dataset/core/tensor.h +++ b/mindspore/ccsrc/dataset/core/tensor.h @@ -93,10 +93,10 @@ class Tensor { // Copy raw data of a array based on shape and strides to the destination pointer // @param dst Pointer to the destination array where the content is to be copied - // @param src Pointer to the source of stided array to be copied + // @param src Pointer to the source of strided array to be copied // @param shape - shape of the source array // @param strides - strides of the source array - // @param type_size - number of bytes needed to store one array elment's type + // @param type_size - number of bytes needed to store one array element's type // @return Status Code static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, std::vector strides, uint8_t type_size); @@ -138,10 +138,10 @@ class Tensor { return Status::OK(); } + // fill tensor with Zeros Status Zero() { dsize_t size = SizeInBytes(); - int retCode = memset_sp(StartAddr(), size, 0, size); - if (retCode != 0) return Status(StatusCode::kUnexpectedError, "Failed to fill tensor with zeroes."); + CHECK_FAIL_RETURN_UNEXPECTED(memset_sp(StartAddr(), size, 0, size) == 0, "Failed to fill tensor with zeroes."); return Status::OK(); } @@ -154,10 +154,7 @@ class Tensor { int64_t cellSize = type_.SizeInBytes(); if ((data_ != nullptr) && type_.IsCompatible()) { for (dsize_t i = 0; i < Size(); i++) { - int retCode = memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize); - if (retCode != 0) { - return Status(StatusCode::kUnexpectedError, "Failed to fill tensor."); - } + CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize) == 0, "memcpy err"); } return Status::OK(); } else { diff --git a/mindspore/ccsrc/dataset/core/tensor_shape.cc b/mindspore/ccsrc/dataset/core/tensor_shape.cc index 24520dc381..3a6514034f 100644 --- a/mindspore/ccsrc/dataset/core/tensor_shape.cc +++ b/mindspore/ccsrc/dataset/core/tensor_shape.cc @@ -87,8 +87,12 @@ TensorShape::TensorShape(const TensorShape &shape) : raw_shape_(*GlobalContext:: TensorShape::TensorShape(py::list l) : raw_shape_(*GlobalContext::Instance()->int_allocator()) { std::vector list_c; - for (auto i : l) { - list_c.push_back(i.cast()); + for (auto &i : l) { + if (!i.is_none()) { + list_c.push_back(i.cast()); + } else { + list_c.push_back(TensorShape::kDimUnknown); + } } AddListToShape(list_c); } diff --git a/mindspore/ccsrc/dataset/core/tensor_shape.h b/mindspore/ccsrc/dataset/core/tensor_shape.h index f908a00ecc..230b36cda2 100644 --- a/mindspore/ccsrc/dataset/core/tensor_shape.h +++ b/mindspore/ccsrc/dataset/core/tensor_shape.h @@ -65,6 +65,10 @@ class TensorShape { // @param shape TensorShape(const TensorShape &shape); + // construct a TensorShape via a python list + // @param py::list l - a list object from python + explicit TensorShape(py::list l); + ~TensorShape() = default; // Create a scalar Shape (i.e., empty shape with mKnown = true) @@ -142,8 +146,6 @@ class TensorShape { return out; } - explicit TensorShape(py::list l); - py::list AsPyList(); // Checks if the given index is a valid index for this tensor. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc index a86633e5b4..018ff99e52 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc @@ -14,15 +14,20 @@ * limitations under the License. */ #include "dataset/engine/datasetops/batch_op.h" + #include #include + #include "common/utils.h" +#include "dataset/core/pybind_support.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" +using float16 = Eigen::half; + namespace mindspore { namespace dataset { -BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false) { +BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false), builder_pad_(false), builder_pad_map_({}) { builder_batch_size_ = batch_size; std::shared_ptr cfg = GlobalContext::config_manager(); builder_num_workers_ = cfg->num_parallel_workers(); @@ -31,8 +36,9 @@ BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false) { Status BatchOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(builder_batch_size_, builder_drop_, builder_op_connector_size_, builder_num_workers_, - builder_cols_to_map_, builder_batch_size_func_, builder_batch_map_func_); + *ptr = std::make_shared(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, + builder_num_workers_, builder_cols_to_map_, builder_batch_size_func_, + builder_batch_map_func_, builder_pad_map_); return Status::OK(); } @@ -44,14 +50,17 @@ Status BatchOp::Builder::SanityCheck() { return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); } -BatchOp::BatchOp(int32_t batch_size, bool drop, int32_t op_queue_size, int32_t num_workers, - const std::vector &cols_to_map, py::function batch_size_func, py::function batch_map_func) +BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &cols_to_map, py::function batch_size_func, py::function batch_map_func, + std::map> pad_map) : ParallelOp(num_workers, op_queue_size), start_batch_size_(batch_size), drop_(drop), - input_column_names_(cols_to_map), + pad_(pad), + pyfunc_column_names_(cols_to_map), batch_size_func_(batch_size_func), - batch_map_func_(batch_map_func) { + batch_map_func_(batch_map_func), + pad_info_(pad_map) { worker_queues_.Init(num_workers, op_queue_size); } @@ -181,7 +190,8 @@ Status BatchOp::WorkerEntry(int32_t workerId) { Status BatchOp::MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, std::unique_ptr *db) { RETURN_UNEXPECTED_IF_NULL(table_pair.first); - if (!input_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc + if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc + if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair)); // do padding if needed (*db) = std::make_unique(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone); std::unique_ptr dest_table = std::make_unique(); RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size())); @@ -206,8 +216,8 @@ Status BatchOp::EoeReceived(int32_t) { Status BatchOp::MapColumns(std::pair, CBatchInfo> *table_pair) { TensorBatchTable input_table; - input_table.reserve(input_column_names_.size()); - for (std::string col_name : input_column_names_) { + input_table.reserve(pyfunc_column_names_.size()); + for (std::string col_name : pyfunc_column_names_) { if (column_name_map_.find(col_name) == column_name_map_.end()) { RETURN_STATUS_UNEXPECTED("column : '" + col_name + "' does not exist\n"); } @@ -225,8 +235,8 @@ Status BatchOp::MapColumns(std::pair, CBatchInfo> RETURN_IF_NOT_OK(InvokeBatchMapFunc(&input_table, &output_table, table_pair->second)); // Write back to TensorQTable - for (size_t input_idx = 0; input_idx < input_column_names_.size(); input_idx++) { - size_t col_idx = static_cast(column_name_map_[input_column_names_[input_idx]]); + for (size_t input_idx = 0; input_idx < pyfunc_column_names_.size(); input_idx++) { + size_t col_idx = static_cast(column_name_map_[pyfunc_column_names_[input_idx]]); size_t row_id = 0; for (TensorRow &row : *(table_pair->first)) { row[col_idx] = std::move(output_table[input_idx][row_id++]); @@ -290,8 +300,8 @@ Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *ou py::object ret_py_obj = batch_map_func_(*input_args); // Parse batch map return value py::tuple ret_tuple = py::cast(ret_py_obj); - if (ret_tuple.size() != input_column_names_.size() || !py::isinstance(ret_tuple)) { - return Status(StatusCode::kPyFuncException, "Batch map function should return an tuple if size(input_columns)"); + if (ret_tuple.size() != pyfunc_column_names_.size() || !py::isinstance(ret_tuple)) { + return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple"); } for (size_t i = 0; i < ret_tuple.size(); i++) { TensorBatch output_batch; @@ -311,5 +321,142 @@ Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *ou } return Status(StatusCode::kOK); } + +Status BatchOp::PadTensor(std::shared_ptr src, std::shared_ptr *dst, + const std::vector &pad_shape, float pad_val) { + CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr"); + if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) { + (*dst) = src; // if no padding, copy the pointer + } else { + CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); + RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type())); + auto tensor_type = src->type().value(); + if (pad_val == 0) { // if pad with zero, don't care what type it is + RETURN_IF_NOT_OK((*dst)->Zero()); + } else if (tensor_type == DataType::DE_INT8) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_BOOL) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT8) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT16) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT16) { + RETURN_IF_NOT_OK((*dst)->Fill(static_cast(pad_val))); + } else if (tensor_type == DataType::DE_UINT16) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else { + RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type"); + } + std::vector cur_ind(src->Rank(), 0), src_s(src->Rank(), 1), dst_s(src->Rank(), 1); + for (dsize_t i = src->Rank() - 2; i >= 0; i--) { + src_s[i] = src->shape()[i + 1] * src_s[i + 1]; + dst_s[i] = pad_shape[i + 1] * dst_s[i + 1]; + } + RETURN_IF_NOT_OK(PadHelper(src, *dst, cur_ind, src_s, dst_s, 0)); + } + return Status::OK(); +} // namespace dataset + +Status BatchOp::PadColumns(std::pair, CBatchInfo> *table_pair) { + RETURN_UNEXPECTED_IF_NULL(table_pair); // placeholder for now, might need this in the future + CHECK_FAIL_RETURN_UNEXPECTED(table_pair->first->front().size() == column_name_map_.size(), "col_name_map mismatch"); + std::vector pad_vals(column_name_map_.size(), 0); // value to pad each column's tensor with, default 0 + std::set pad_cols; + // padded_shape provided by user, maximum shapes of current batch of tensors + std::vector> pad_shapes(column_name_map_.size()), max_shapes(column_name_map_.size()); + RETURN_IF_NOT_OK(UnpackPadInfo(&pad_cols, &pad_vals, &pad_shapes)); + + // init each shape in max_shape to {-1,-1...} init each unspecified shape in pad_shape to -1 as well + for (size_t col_id : pad_cols) { + max_shapes[col_id] = std::vector(table_pair->first->front()[col_id]->Rank(), -1); + if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id]; // fill pad shape with -1 + CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape"); + } + + // calculate maximum shape for each column that needs to be padded + for (const TensorRow &row : *(table_pair->first)) { // iterator each row in a batch + for (size_t col_id : pad_cols) { // iterator each tensor in a row + CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(), + "Tensor to be padded together need to have the same rank"); + for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) { // pick the largest number in each dimension + max_shapes[col_id][dim] = std::max(max_shapes[col_id][dim], row[col_id]->shape()[dim]); + } + } + } + + // if user sets a dimension to -1 (None in python), use the max value for current dimension + for (size_t col_id : pad_cols) { + for (size_t dim = 0; dim < pad_shapes[col_id].size(); dim++) { + if (pad_shapes[col_id][dim] < 0) pad_shapes[col_id][dim] = max_shapes[col_id][dim]; + } + } + + // call pad on each tensor that needs to be padded + for (TensorRow &row : *(table_pair->first)) { + for (size_t col_id : pad_cols) { + std::shared_ptr pad_tensor; + RETURN_IF_NOT_OK(PadTensor(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id])); + row[col_id] = pad_tensor; + } + } + return Status::OK(); +} + +Status BatchOp::UnpackPadInfo(std::set *pad_cols, std::vector *pad_vals, + std::vector> *pad_shapes) { + if (pad_info_.empty()) { // if pad_info empty, pad every columns automatically + for (dsize_t col_id = 0; col_id < column_name_map_.size(); col_id++) { + pad_cols->insert(col_id); + } + } else { + for (auto p : pad_info_) { + CHECK_FAIL_RETURN_UNEXPECTED(column_name_map_.find(p.first) != column_name_map_.end(), + "no column exists with name:" + p.first); + dsize_t col_id = static_cast(column_name_map_[p.first]); + CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound"); + pad_cols->insert(col_id); + (*pad_vals)[col_id] = p.second.second; // set pad values + (*pad_shapes)[col_id] = p.second.first.AsVector(); // empty vector if shape is unknown + } + } + return Status::OK(); +} + +Status BatchOp::PadHelper(std::shared_ptr src, std::shared_ptr dst, std::vector cur_ind, + const std::vector &src_s, const std::vector &dst_s, size_t cur_dim) { + if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data + uint8_t type_size = src->type().SizeInBytes(); + size_t len = std::min(src->shape()[cur_dim], dst->shape()[cur_dim]) * type_size; + dsize_t src_flat_ind = 0, dst_flat_ind = 0; + for (size_t i = 0; i < src->Rank(); i++) { + src_flat_ind += src_s[i] * cur_ind[i]; + dst_flat_ind += dst_s[i] * cur_ind[i]; + } + unsigned char *src_addr = src->StartAddr() + src_flat_ind * type_size; + unsigned char *dst_addr = dst->StartAddr() + dst_flat_ind * type_size; + CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error"); + } else { // not the last dimension, keep doing recursion + dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); + for (dsize_t i = 0; i < min_ind; i++) { + cur_ind[cur_dim] = i; + RETURN_IF_NOT_OK(PadHelper(src, dst, cur_ind, src_s, dst_s, cur_dim + 1)); + } + } + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h index 32d386e3c9..f17239e378 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h @@ -16,8 +16,11 @@ #ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ #define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ +#include +#include #include #include +#include #include #include #include @@ -44,10 +47,6 @@ class BatchOp : public ParallelOp { // @param int32_t batch_size explicit Builder(int32_t batch_size); - // Builder constructor for Batch, batch size function needs to be specified - // @param py::function batch_size_func - explicit Builder(py::function batch_size_func); - // Default destructor ~Builder() = default; @@ -67,6 +66,12 @@ class BatchOp : public ParallelOp { return *this; } + Builder &SetPaddingMap(const std::map> &pad_map, bool pad = true) { + builder_pad_ = pad; + builder_pad_map_ = pad_map; + return *this; + } + // set connector size for batch // @param int32_t op_conn_size // @return Builder & reference to builder class object @@ -109,11 +114,12 @@ class BatchOp : public ParallelOp { Status SanityCheck(); bool builder_drop_; + bool builder_pad_; int32_t builder_batch_size_; int32_t builder_num_workers_; int32_t builder_op_connector_size_; std::vector builder_cols_to_map_; - + std::map> builder_pad_map_; py::function builder_batch_size_func_; py::function builder_batch_map_func_; }; @@ -143,8 +149,9 @@ class BatchOp : public ParallelOp { // @param int32_t op_queue_size // @param int32_t rows_per_buf // @param int32_t num_workers - BatchOp(int32_t batch_size, bool drop, int32_t op_queue_size, int32_t num_workers, const std::vector &, - py::function batch_size_func, py::function batch_map_func); + BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &, py::function batch_size_func, py::function batch_map_func, + std::map> pad_map); // BatchOp destructor ~BatchOp() {} @@ -176,7 +183,28 @@ class BatchOp : public ParallelOp { // @return Status - The error code return Status operator()() override; + // Pad input tensor according pad_shape, need to have same rank. + // @param std::shared_ptr src - tensor to pad from + // @param std::shared_ptr *dst - return tensor padded + // @param std::vector pad_shape - shape to pad to + // @param float pad_val - value to pad with + // @return - The error code return + Status PadTensor(std::shared_ptr src, std::shared_ptr *dst, const std::vector &pad_shape, + float pad_val); + private: + // recursive helper function. This function could be very expensive if called on a multi-dimensional tensor + // it is only meant to be called by PadTensor. + // @tparam T - type of tensor and fill value + // @param std::shared_ptr src - Tensor to pad from + // @param std::shared_ptr* dst - Tensor to pad to, return value + // @param std::vector cur_ind - recursion helper + // @param T pad_val - value to pad tensor with + // @param size_t cur_dim - recursion helper + // @return Status - The error code return + Status PadHelper(std::shared_ptr src, std::shared_ptr dst, std::vector cur_ind, + const std::vector &src_s, const std::vector &dst_s, size_t cur_dim = 0); + // Worker thread for doing the memcpy of batch // @param int32_t param workerId // @return Status - The error code return @@ -199,6 +227,16 @@ class BatchOp : public ParallelOp { // @return Status - The error code return Status MapColumns(std::pair, CBatchInfo> *table_pair); + // @param std::set *cols, col ids to perform pad on + // @param std::vector *vals, default padding value for each column + // @param std::vector> *shapes, padding shape specified by user + // @return Status - The error code return + Status UnpackPadInfo(std::set *cols, std::vector *vals, std::vector> *shapes); + + // @param table_pair + // @return Status - The error code return + Status PadColumns(std::pair, CBatchInfo> *table_pair); + // the number of thread pulling from the mOutConnector of the Op below // @return int32_t, 1 int32_t num_consumers() const override { return 1; } @@ -220,19 +258,15 @@ class BatchOp : public ParallelOp { Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info); int32_t start_batch_size_; - bool drop_; - // Name of the columns to perform map op on - std::vector input_column_names_; - // Iterator for fetching - std::unique_ptr child_iterator_; - // Map of column_name: column_index - std::unordered_map column_name_map_; - // Internal queue for task distribution - QueueList, CBatchInfo>> worker_queues_; - // Function pointer of batch size function - py::function batch_size_func_; - // Function pointer of per batch map function - py::function batch_map_func_; + bool drop_; // bool for whether to drop remainder or not + bool pad_; // bool for whether to perform padding on tensor + std::vector pyfunc_column_names_; // Name of the columns to perform map op on + std::map> pad_info_; // column names to perform padding on + std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 + std::unordered_map column_name_map_; // Map of column_name: column_index + QueueList, CBatchInfo>> worker_queues_; // internal queue for syncing worker + py::function batch_size_func_; // Function pointer of batch size function + py::function batch_map_func_; // Function pointer of per batch map function }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 06b740bb6b..73bd025e19 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -40,7 +40,8 @@ from mindspore._c_expression import typing from mindspore import log as logger from . import samplers from .iterators import DictIterator, TupleIterator -from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \ +from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ + check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset @@ -163,7 +164,7 @@ class Dataset: @check_batch def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, - input_columns=None): + input_columns=None, pad_info=None): """ Combines batch_size number of consecutive rows into batches. @@ -181,7 +182,7 @@ class Dataset: drop_remainder (bool, optional): Determines whether or not to drop the last possibly incomplete batch (default=False). If True, and if there are less than batch_size rows available to make the last batch, then those rows will - be dropped and not propogated to the child node. + be dropped and not propagated to the child node. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None). per_batch_map (callable, optional): Per batch map callable. A callable which takes (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represent a batch of @@ -189,6 +190,8 @@ class Dataset: last parameter of the callable should always be a BatchInfo object. input_columns (list of string, optional): List of names of the input columns. The size of the list should match with signature of per_batch_map callable. + pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)} + would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0. Returns: BatchDataset, dataset batched. @@ -200,7 +203,8 @@ class Dataset: >>> # and drops the last incomplete batch if there is one. >>> data = data.batch(100, True) """ - return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns) + return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, + pad_info) @check_sync_wait def sync_wait(self, condition_name, num_batch=1, callback=None): @@ -1026,13 +1030,26 @@ class BatchDataset(DatasetOp): Args: input_dataset (Dataset): Input Dataset to be batched. - batch_size (int): The size of the batch. - drop_remainder (bool, optional): Whether drop the remainder batch of data (drop_remainder=False). - If True, the last incomplete batch will be dropped. + batch_size (int or function): The number of rows each batch is created with. An + int or callable which takes exactly 1 parameter, BatchInfo. + drop_remainder (bool, optional): Determines whether or not to drop the last + possibly incomplete batch (default=False). If True, and if there are less + than batch_size rows available to make the last batch, then those rows will + be dropped and not propagated to the child node. + num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None). + per_batch_map (callable, optional): Per batch map callable. A callable which takes + (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represent a batch of + Tensors on a given column. The number of lists should match with number of entries in input_columns. The + last parameter of the callable should always be a BatchInfo object. + input_columns (list of string, optional): List of names of the input columns. The size of the list should + match with signature of per_batch_map callable. + pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)} + would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0. + """ def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, - per_batch_map=None, input_columns=None): + per_batch_map=None, input_columns=None, pad_info=None): super().__init__(num_parallel_workers) if BatchDataset._is_ancestor_of_repeat(input_dataset): @@ -1044,6 +1061,7 @@ class BatchDataset(DatasetOp): self.drop_remainder = drop_remainder self.per_batch_map = per_batch_map self.input_columns = input_columns + self.pad_info = pad_info self.input.append(input_dataset) input_dataset.output.append(self) self._input_indexs = input_dataset.input_indexs @@ -1054,6 +1072,7 @@ class BatchDataset(DatasetOp): args["drop_remainder"] = self.drop_remainder args["per_batch_map"] = self.per_batch_map args["input_columns"] = self.input_columns + args["pad_info"] = self.pad_info return args def get_dataset_size(self): @@ -2702,6 +2721,7 @@ class TFRecordDataset(SourceDataset): >>> # 3) get all rows from dataset_files with schema file "./schema.json": >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json") """ + @check_tfrecorddataset def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): @@ -3551,6 +3571,7 @@ class CelebADataset(SourceDataset): args["shard_id"] = self.shard_id return args + class TextFileDataset(SourceDataset): """ A source dataset that reads and parses datasets stored on disk in text format. diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 4f1bb2c2d7..fd6ecfffb0 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -324,6 +324,7 @@ def check_sampler_shuffle_shard_options(param_dict): def check_imagefolderdatasetv2(method): """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -356,6 +357,7 @@ def check_imagefolderdatasetv2(method): def check_mnist_cifar_dataset(method): """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -382,6 +384,7 @@ def check_mnist_cifar_dataset(method): def check_manifestdataset(method): """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -414,6 +417,7 @@ def check_manifestdataset(method): def check_tfrecorddataset(method): """A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -444,6 +448,7 @@ def check_tfrecorddataset(method): def check_vocdataset(method): """A wrapper that wrap a parameter checker to the original Dataset(VOCDataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -470,6 +475,7 @@ def check_vocdataset(method): def check_celebadataset(method): """A wrapper that wrap a parameter checker to the original Dataset(CelebADataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -510,6 +516,7 @@ def check_celebadataset(method): def check_minddataset(method): """A wrapper that wrap a parameter checker to the original Dataset(MindDataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -541,6 +548,7 @@ def check_minddataset(method): def check_generatordataset(method): """A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -628,8 +636,25 @@ def check_columns(columns, name): raise TypeError("{} should be either a list of strings or a single string.".format(name)) +def check_pad_info(key, val): + """check the key and value pair of pad_info in batch""" + check_type(key, "key in pad_info", str) + if val is not None: + assert len(val) == 2, "value of pad_info should be a tuple of size 2" + check_type(val, "value in pad_info", tuple) + if val[0] is not None: + check_type(val[0], "pad_shape", list) + for dim in val[0]: + if dim is not None: + check_type(dim, "dim in pad_shape", int) + assert dim > 0, "pad shape should be positive integers" + if val[1] is not None: + check_type(val[1], "pad_value", (int, float)) + + def check_batch(method): """check the input arguments of batch.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -648,6 +673,14 @@ def check_batch(method): check_param_type(nreq_param_bool, param_dict, bool) + if (param_dict.get('pad_info') is not None) and (param_dict.get('per_batch_map') is not None): + raise ValueError("pad_info and per_batch_map can't both be set") + + if param_dict.get('pad_info') is not None: + check_type(param_dict["pad_info"], "pad_info", dict) + for k, v in param_dict.get('pad_info').items(): + check_pad_info(k, v) + for param_name in nreq_param_columns: param = param_dict.get(param_name) if param is not None: @@ -687,6 +720,7 @@ def check_sync_wait(method): def check_shuffle(method): """check the input arguments of shuffle.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -705,6 +739,7 @@ def check_shuffle(method): def check_map(method): """check the input arguments of map.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -729,6 +764,7 @@ def check_map(method): def check_filter(method): """"check the input arguments of filter.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -749,6 +785,7 @@ def check_filter(method): def check_repeat(method): """check the input arguments of repeat.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -764,6 +801,7 @@ def check_repeat(method): def check_skip(method): """check the input arguments of skip.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -780,6 +818,7 @@ def check_skip(method): def check_take(method): """check the input arguments of take.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -794,6 +833,7 @@ def check_take(method): def check_zip(method): """check the input arguments of zip.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -811,6 +851,7 @@ def check_zip(method): def check_zip_dataset(method): """check the input arguments of zip method in `Dataset`.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -830,6 +871,7 @@ def check_zip_dataset(method): def check_rename(method): """check the input arguments of rename.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -849,6 +891,7 @@ def check_rename(method): def check_project(method): """check the input arguments of project.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -876,6 +919,7 @@ def check_shape(shape, name): def check_add_column(method): """check the input arguments of add_column.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -905,6 +949,7 @@ def check_add_column(method): def check_textfiledataset(method): """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) diff --git a/tests/ut/cpp/dataset/batch_op_test.cc b/tests/ut/cpp/dataset/batch_op_test.cc index 504cac51e5..866ebc9b19 100644 --- a/tests/ut/cpp/dataset/batch_op_test.cc +++ b/tests/ut/cpp/dataset/batch_op_test.cc @@ -30,16 +30,14 @@ namespace common = mindspore::common; namespace de = mindspore::dataset; using namespace mindspore::dataset; -using mindspore::MsLogLevel::ERROR; -using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; class MindDataTestBatchOp : public UT::DatasetOpTesting { protected: - }; - std::shared_ptr Batch(int32_t batch_size = 1, bool drop = false, int rows_per_buf = 2) { Status rc; std::shared_ptr op; @@ -93,10 +91,8 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) { rc = di.GetNextAsMap(&tensor_map); EXPECT_TRUE(rc.IsOk()); std::shared_ptr t; - rc = de::Tensor::CreateTensor(&t, - TensorImpl::kFlexible, de::TensorShape({12, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) payload); + rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload); EXPECT_TRUE(rc.IsOk()); // verify the actual data in Tensor is correct EXPECT_EQ(*t == *tensor_map["col_sint64"], true); @@ -111,7 +107,6 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) { EXPECT_EQ(success, true); } - TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { std::string schema_file = datasets_root_path_ + "/testBatchDataset"; bool success = false; @@ -125,20 +120,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2, t3; - rc = de::Tensor::CreateTensor(&t1, - TensorImpl::kFlexible, de::TensorShape({7, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) payload); + rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, - TensorImpl::kFlexible, de::TensorShape({7, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) (payload + 7)); + rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7)); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t3, - TensorImpl::kFlexible, de::TensorShape({7, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) (payload + 2)); + rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 2)); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -163,7 +152,6 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) { EXPECT_EQ(success, true); } - TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { std::string schema_file = datasets_root_path_ + "/testBatchDataset"; bool success = false; @@ -177,25 +165,17 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2, t3, t4; - rc = de::Tensor::CreateTensor(&t1, - TensorImpl::kFlexible, de::TensorShape({7, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) payload); + rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, - TensorImpl::kFlexible, de::TensorShape({7, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) (payload + 7)); + rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7)); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t3, - TensorImpl::kFlexible, de::TensorShape({7, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) (payload + 2)); + rc = de::Tensor::CreateTensor(&t3, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 2)); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t4, - TensorImpl::kFlexible, de::TensorShape({3, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) (payload + 9)); + rc = de::Tensor::CreateTensor(&t4, TensorImpl::kFlexible, de::TensorShape({3, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 9)); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -224,7 +204,6 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) { EXPECT_EQ(success, true); } - TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { std::string schema_file = datasets_root_path_ + "/testBatchDataset"; bool success = false; @@ -238,15 +217,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2; - rc = de::Tensor::CreateTensor(&t1, - TensorImpl::kFlexible, de::TensorShape({7, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) payload); + rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({7, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, - TensorImpl::kFlexible, de::TensorShape({5, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) (payload + 7)); + rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 7)); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -275,7 +250,6 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) { EXPECT_EQ(success, true); } - TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) { std::string schema_file = datasets_root_path_ + "/testBatchDataset"; bool success = false; @@ -289,15 +263,11 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) { -9223372036854775807 - 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9223372036854775807}; de::DatasetIterator di(tree); std::shared_ptr t1, t2; - rc = de::Tensor::CreateTensor(&t1, - TensorImpl::kFlexible, de::TensorShape({5, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) payload); + rc = de::Tensor::CreateTensor(&t1, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload); EXPECT_TRUE(rc.IsOk()); - rc = de::Tensor::CreateTensor(&t2, - TensorImpl::kFlexible, de::TensorShape({5, 1}), - de::DataType(DataType::DE_INT64), - (unsigned char *) (payload + 5)); + rc = de::Tensor::CreateTensor(&t2, TensorImpl::kFlexible, de::TensorShape({5, 1}), de::DataType(DataType::DE_INT64), + (unsigned char *)(payload + 5)); EXPECT_TRUE(rc.IsOk()); TensorMap tensor_map; @@ -325,3 +295,31 @@ TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) { } EXPECT_EQ(success, true); } + +TEST_F(MindDataTestBatchOp, TestSimpleBatchPadding) { + std::string schema_file = datasets_root_path_ + "/testBatchDataset"; + std::shared_ptr op; + std::map> m; + m.insert({"col_1d", std::make_pair(TensorShape({4}), -1)}); + de::BatchOp::Builder(12).SetDrop(false).SetPaddingMap(m, true).Build(&op); + auto tree = Build({Storage(schema_file), op}); + tree->Prepare(); + Status rc = tree->Launch(); + if (rc.IsError()) { + MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << "."; + } else { + int64_t payload[] = {-9223372036854775807 - 1, 1, -1, -1, 2, 3, -1, -1, 4, 5, -1, -1, 6, 7, -1, -1, + 8, 9, -1, -1, 10, 11, -1, -1, 12, 13, -1, -1, 14, 15, -1, -1, + 16, 17, -1, -1, 18, 19, -1, -1, 20, 21, -1, -1, 22, 23, -1, -1}; + std::shared_ptr t; + rc = de::Tensor::CreateTensor(&t, TensorImpl::kFlexible, de::TensorShape({12, 4}), de::DataType(DataType::DE_INT64), + (unsigned char *)payload); + de::DatasetIterator di(tree); + TensorMap tensor_map; + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE((*t) == (*(tensor_map["col_1d"]))); + rc = di.GetNextAsMap(&tensor_map); + EXPECT_TRUE(tensor_map.size() == 0); + EXPECT_TRUE(rc.IsOk()); + } +} diff --git a/tests/ut/python/dataset/test_pad_batch.py b/tests/ut/python/dataset/test_pad_batch.py new file mode 100644 index 0000000000..7cfc34e718 --- /dev/null +++ b/tests/ut/python/dataset/test_pad_batch.py @@ -0,0 +1,213 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import mindspore.dataset as ds +import numpy as np +import time + + +# This UT test tests the following cases + +# 1. padding: input_shape=[x] output_shape=[y] where y > x +# 2. padding in one dimension and truncate in the other. input_shape=[x1,x2] output_shape=[y1,y2] y1>x1 and y2