!8051 BatchNode should support py::function for per_batch_map and batch_size_func, paving the way for the pybind layer to switch to using IR

Merge pull request !8051 from ZiruiWu/batch_cpp_api_pyfunc
pull/8051/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 7e1b1f280a

@ -478,10 +478,7 @@ std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &dat
// Function to create a Batch dataset
// Builds a BatchNode on top of the input dataset's IR node and stores it, upcast to DatasetNode, in ir_node_.
// NOTE(review): this listing looks like a diff whose +/- markers were stripped, so removed and added lines
// sit side by side. The two `auto ds` declarations below cannot both be live code - confirm against the commit.
BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) {
// Default values
// NOTE(review): these three locals are only consumed by the six-argument make_shared call below.
std::vector<std::string> cols_to_map = {};
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
bool pad = false;
// Six-argument form: passes the pad flag, the column list and the pad map explicitly.
auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder, pad, cols_to_map, pad_map);
// Three-argument form: matches BatchNode constructor #2 (C++ API), which sets pad_ to false on its own.
auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

@ -28,16 +28,29 @@ namespace mindspore {
namespace dataset {
namespace api {
#ifdef ENABLE_PYTHON
// constructor #1, called by Pybind
// Copies every argument into the matching member and appends `child` to this node's children.
// Only compiled when ENABLE_PYTHON is defined, since it takes py::function arguments
// (batch_size_func and batch_map_func).
// NOTE(review): the `cols_to_map` parameter line and the `cols_to_map_` initializer appear next to their
// `in_col_names` / `out_col_names` counterparts; presumably they are the removed side of a diff whose
// +/- markers were stripped - confirm before treating both as part of the signature.
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
std::vector<std::string> cols_to_map,
const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
py::function batch_size_func, py::function batch_map_func,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
: batch_size_(batch_size),
drop_remainder_(drop_remainder),
pad_(pad),
cols_to_map_(cols_to_map),
in_col_names_(in_col_names),
out_col_names_(out_col_names),
batch_size_func_(batch_size_func),
batch_map_func_(batch_map_func),
pad_map_(pad_map) {
// The child node is stored in the children list rather than in a dedicated member.
this->children.push_back(child);
}
#endif
// constructor #2, called by C++ API
// Takes only the batch size and the drop-remainder flag. pad_ is fixed to false; the members that do not
// appear in the initializer list are left default-constructed. `child` is appended to this node's children.
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder)
: batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) {
this->children.push_back(child);
}
// Validates the parameters held by this node. Each failed check logs its message with MS_LOG(ERROR) and
// leaves through RETURN_STATUS_SYNTAX_ERROR; when every check passes the function returns Status::OK().
// NOTE(review): a diff hunk header interrupts the body right after the batch_size_ test, so the statement
// that builds err_msg for that check is not visible in this listing.
Status BatchNode::ValidateParams() {
if (batch_size_ <= 0) {
@ -45,11 +58,20 @@ Status BatchNode::ValidateParams() {
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// NOTE(review): the cols_to_map_ check below is never closed before the #ifdef that follows; presumably these
// two lines are the removed side of the diff - confirm against the commit.
if (!cols_to_map_.empty()) {
std::string err_msg = "BatchNode: cols_to_map functionality is not implemented in C++; this should be left empty.";
#ifdef ENABLE_PYTHON
// Python-only checks: both depend on batch_map_func_, which only exists when ENABLE_PYTHON is defined.
// A per-batch map function may not be combined with padding.
if (batch_map_func_ && pad_) {
std::string err_msg = "BatchNode: per_batch_map and pad should not be used at the same time.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// A per-batch map function requires at least one input column name.
if (batch_map_func_ && in_col_names_.empty()) {
std::string err_msg = "BatchNode: in_col_names cannot be empty when per_batch_map is used.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
#endif
return Status::OK();
}
@ -58,16 +80,19 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
// Collects the runtime op(s) created for this node and returns them in insertion order.
// NOTE(review): the function signature is only visible in the hunk header above, and this listing mixes
// removed and added diff lines; the competing argument tails below are flagged individually.
std::vector<std::shared_ptr<DatasetOp>> node_ops;
#ifdef ENABLE_PYTHON
// NOTE(review): `noop` is only referenced by the first argument tail below (the cols_to_map_ one).
py::function noop;
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
// NOTE(review): two alternative argument tails follow; presumably the cols_to_map_/noop line is the removed one
// and the in_col_names_/out_col_names_ lines, which forward the stored py::functions, replace it - confirm.
cols_to_map_, cols_to_map_, noop, noop, pad_map_));
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_,
pad_map_));
// need to insert a project when per_batch_func changes the number of columns
// The ProjectOp is appended after the BatchOp whenever out_col_names_ is non-empty.
if (!out_col_names_.empty()) {
auto project_op = std::make_shared<ProjectOp>(out_col_names_);
node_ops.push_back(project_op);
}
#else
// Without Python support BatchOp is built without any py::function arguments.
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
// NOTE(review): again two alternative argument tails (cols_to_map_ vs in_col_names_) - confirm which is live.
cols_to_map_, pad_map_));
in_col_names_, pad_map_));
#endif
// NOTE(review): the two-line comment below says no project op is needed, which conflicts with the ProjectOp
// insertion in the ENABLE_PYTHON branch above; it looks stale (possibly a removed diff line) - confirm.
// Until py::function is implemented for C++ API, there is no need for a project op to be inserted after batch
// because project is only needed when batch op performs per_batch_map. This per_batch_map is a pyfunc
return node_ops;
}

@ -31,10 +31,16 @@ namespace api {
/// \brief IR node describing a batch operation; derives from DatasetNode.
/// \note NOTE(review): this listing is cut by a diff hunk header between the destructor and the data members,
///     so any declarations in between are not visible here.
class BatchNode : public DatasetNode {
public:
// NOTE(review): the bare "Constructor" brief below is immediately followed by the "Constructor #1" brief;
// presumably it is a removed diff line - confirm.
/// \brief Constructor
#ifdef ENABLE_PYTHON
/// \brief Constructor #1, for Python API to create a BatchNode
// Takes py::function callbacks (batch_size_func, batch_map_func), hence the ENABLE_PYTHON guard.
// NOTE(review): the `cols_to_map` parameter line sits next to its in_col_names/out_col_names counterpart and
// is likely the removed side of the diff - confirm.
BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
std::vector<std::string> cols_to_map,
const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
py::function batch_size_func, py::function batch_map_func,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
#endif
/// \brief Constructor #2 for C++ API to create a BatchNode
BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder);
/// \brief Destructor
~BatchNode() = default;
@ -51,7 +57,12 @@ class BatchNode : public DatasetNode {
// Data members. batch_size_, drop_remainder_ and pad_ are set by both constructors.
int32_t batch_size_;
bool drop_remainder_;
bool pad_;
// NOTE(review): cols_to_map_ appears alongside in_col_names_/out_col_names_; presumably it is the member
// those two replace - confirm against the commit.
std::vector<std::string> cols_to_map_;
std::vector<std::string> in_col_names_;
std::vector<std::string> out_col_names_;
#ifdef ENABLE_PYTHON
// Python callbacks; only present in builds with ENABLE_PYTHON defined.
py::function batch_size_func_;
py::function batch_map_func_;
#endif
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
};

Loading…
Cancel
Save