|
|
|
@ -28,16 +28,29 @@ namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
|
namespace api {
|
|
|
|
|
|
|
|
|
|
#ifdef ENABLE_PYTHON
|
|
|
|
|
// constructor #1, called by Pybind
|
|
|
|
|
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
|
|
|
|
|
std::vector<std::string> cols_to_map,
|
|
|
|
|
const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
|
|
|
|
|
py::function batch_size_func, py::function batch_map_func,
|
|
|
|
|
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
|
|
|
|
|
: batch_size_(batch_size),
|
|
|
|
|
drop_remainder_(drop_remainder),
|
|
|
|
|
pad_(pad),
|
|
|
|
|
cols_to_map_(cols_to_map),
|
|
|
|
|
in_col_names_(in_col_names),
|
|
|
|
|
out_col_names_(out_col_names),
|
|
|
|
|
batch_size_func_(batch_size_func),
|
|
|
|
|
batch_map_func_(batch_map_func),
|
|
|
|
|
pad_map_(pad_map) {
|
|
|
|
|
this->children.push_back(child);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// constructor #2, called by C++ API
|
|
|
|
|
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder)
|
|
|
|
|
: batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) {
|
|
|
|
|
this->children.push_back(child);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Validates this node's parameters.
// Each failing check logs its message at ERROR level and returns through
// RETURN_STATUS_SYNTAX_ERROR; when no check fails the function returns Status::OK().
Status BatchNode::ValidateParams() {
|
|
|
|
|
// batch_size_ must be strictly positive.
if (batch_size_ <= 0) {
|
|
|
|
@ -45,11 +58,20 @@ Status BatchNode::ValidateParams() {
|
|
|
|
|
// NOTE(review): err_msg is used below but its declaration for this branch is not visible
// in this excerpt (it sits in the gap before the hunk marker above) - confirm against the full file.
MS_LOG(ERROR) << err_msg;
|
|
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
|
|
|
|
}
|
|
|
|
|
// NOTE(review): this cols_to_map_ check has no visible log/return/closing brace in this
// excerpt; it looks like a leftover line from the previous revision - confirm against the full file.
if (!cols_to_map_.empty()) {
|
|
|
|
|
std::string err_msg = "BatchNode: cols_to_map functionality is not implemented in C++; this should be left empty.";
|
|
|
|
|
|
|
|
|
|
|
// The per_batch_map checks below are compiled only when ENABLE_PYTHON is defined.
#ifdef ENABLE_PYTHON
|
|
|
|
|
// A per-batch map function and pad are mutually exclusive.
if (batch_map_func_ && pad_) {
|
|
|
|
|
std::string err_msg = "BatchNode: per_batch_map and pad should not be used at the same time.";
|
|
|
|
|
MS_LOG(ERROR) << err_msg;
|
|
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
// A per-batch map function needs at least one input column name.
if (batch_map_func_ && in_col_names_.empty()) {
|
|
|
|
|
std::string err_msg = "BatchNode: in_col_names cannot be empty when per_batch_map is used.";
|
|
|
|
|
MS_LOG(ERROR) << err_msg;
|
|
|
|
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
// None of the checks above failed.
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -58,16 +80,19 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
|
|
|
|
|
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
|
|
|
|
|
|
|
|
|
#ifdef ENABLE_PYTHON
|
|
|
|
|
py::function noop;
|
|
|
|
|
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
|
|
|
|
cols_to_map_, cols_to_map_, noop, noop, pad_map_));
|
|
|
|
|
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_,
|
|
|
|
|
pad_map_));
|
|
|
|
|
// need to insert a project when per_batch_func changes the number of columns
|
|
|
|
|
if (!out_col_names_.empty()) {
|
|
|
|
|
auto project_op = std::make_shared<ProjectOp>(out_col_names_);
|
|
|
|
|
node_ops.push_back(project_op);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
|
|
|
|
cols_to_map_, pad_map_));
|
|
|
|
|
in_col_names_, pad_map_));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// Until py::function is implemented for C++ API, there is no need for a project op to be inserted after batch
|
|
|
|
|
// because project is only needed when batch op performs per_batch_map. This per_batch_map is a pyfunc
|
|
|
|
|
return node_ops;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|