!9335 Add multiprocessing support to per_batch_map

From: @hfarahat
Reviewed-by: 
Signed-off-by:
pull/9335/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit 4f607e9c26

@@ -24,7 +24,18 @@ PYBIND_REGISTER(CBatchInfo, 0, ([](const py::module *m) {
                     (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
                       .def(py::init<int64_t, int64_t, int64_t>())
                       .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
-                      .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
+                      .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num)
+                      .def(py::pickle(
+                        [](const BatchOp::CBatchInfo &p) {  // __getstate__
+                          /* Return a tuple that fully encodes the state of the object */
+                          return py::make_tuple(p.epoch_num_, p.batch_num_, p.total_batch_num_);
+                        },
+                        [](py::tuple t) {  // __setstate__
+                          if (t.size() != 3) throw std::runtime_error("Invalid state!");
+                          /* Create a new C++ instance */
+                          BatchOp::CBatchInfo p(t[0].cast<int64_t>(), t[1].cast<int64_t>(), t[2].cast<int64_t>());
+                          return p;
+                        }));
                   }));

 PYBIND_REGISTER(DatasetOp, 0, ([](const py::module *m) {
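Making CBatchInfo picklable is what lets a BatchInfo argument cross a process boundary: multiprocessing pickles everything it ships to a pool worker, so without __getstate__/__setstate__ any per_batch_map that receives a BatchInfo would fail inside the pool. A minimal round-trip check, assuming the module path dataset.py itself imports and the constructor order implied by the tuple above:

    import pickle
    import mindspore._c_dataengine as cde

    # Constructor argument order assumed from the __getstate__ tuple above:
    # (epoch_num, batch_num, total_batch_num).
    info = cde.CBatchInfo(0, 5, 10)
    clone = pickle.loads(pickle.dumps(info))  # exercises the new __getstate__/__setstate__
    assert clone.get_epoch_num() == 0
    assert clone.get_batch_num() == 5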

@@ -158,22 +158,14 @@ class Dataset:
         if len(self.parent) > 1:
             raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
         ir_children = [d.parse_tree() for d in self.children]
+        # Bootstrap can only be performed on a copy of the original dataset node.
+        # Bootstrap on the original dataset node would make all iterators share the same process pool.
+        self.iterator_bootstrap()
         ir_node = self.parse(ir_children)
-        return self._alter_node(ir_node)
-
-    @staticmethod
-    def _alter_node(node):
-        """
-        Internal method to add process pool to copied map node.
-
-        Returns:
-            DatasetNode. The altered node.
-        """
-        if isinstance(node, MapDataset):
-            if node.python_multiprocessing:
-                # Bootstrap can only be performed on a copy of the original dataset node.
-                # Bootstrap on original dataset node will make all iterators share the same process pool
-                node.iterator_bootstrap()
-        return node
+        return ir_node
+
+    def iterator_bootstrap(self):
+        pass

     @staticmethod
     def _noop_mode():
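Folding the old _alter_node logic into parse_tree() preserves the invariant stated in the comment: a process pool is only ever created on the deep copy that an iterator owns, never on the user's original pipeline object, so two live iterators cannot share workers. A self-contained sketch of that pattern, stdlib only, with hypothetical names:

    import copy
    import multiprocessing

    def _square(x):
        return x * x

    class Node:
        """Stand-in for a dataset node; the template never owns a pool."""

        def __init__(self):
            self.process_pool = None

        def iterator_bootstrap(self):
            # Called on the copy only, mirroring Dataset.parse_tree() above.
            self.process_pool = multiprocessing.Pool(processes=2)

    if __name__ == "__main__":
        template = Node()
        it_node = copy.deepcopy(template)      # copy made before bootstrap
        it_node.iterator_bootstrap()           # per-iterator pool lives on the copy
        print(it_node.process_pool.map(_square, [1, 2, 3]))  # [1, 4, 9]
        assert template.process_pool is None   # original stays pool-free
        it_node.process_pool.close()
        it_node.process_pool.join()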
@@ -272,7 +264,7 @@ class Dataset:
     @check_batch
     def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
-              input_columns=None, output_columns=None, column_order=None, pad_info=None):
+              input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False):
         """
         Combine batch_size number of consecutive rows into batches.
@@ -312,6 +304,8 @@
                 same).
             pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
                 would pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
+            python_multiprocessing (bool, optional): Parallelize the Python function per_batch_map with multiple
+                worker processes. This option can be beneficial if the function is computationally heavy (default=False).

         Returns:
             BatchDataset, dataset batched.
@@ -339,7 +333,7 @@
         >>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
         """
         return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
-                            output_columns, column_order, pad_info)
+                            output_columns, column_order, pad_info, python_multiprocessing)

     @check_sync_wait
     def sync_wait(self, condition_name, num_batch=1, callback=None):
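With the flag threaded through to BatchDataset, turning it on is a one-argument change at the call site. A hedged usage sketch: the source dataset, column name, and the per-batch callable are assumptions, and per_batch_map keeps its usual contract of receiving the batched column(s) plus a BatchInfo and returning a tuple of column lists:

    import numpy as np
    import mindspore.dataset as ds

    def heavy_map(col, batch_info):
        # Stand-in for CPU-heavy per-batch work worth pushing to worker processes.
        return ([np.copy(x[::2, ::2]) for x in col],)

    data = ds.ImageFolderDataset("path/to/images")  # hypothetical source
    data = data.batch(batch_size=8, input_columns=["image"],
                      per_batch_map=heavy_map, num_parallel_workers=4,
                      python_multiprocessing=True)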
@@ -1835,7 +1829,8 @@ class BatchDataset(Dataset):
     """

     def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None,
-                 per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None):
+                 per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None,
+                 python_multiprocessing=False):
         super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)

         if BatchDataset._is_ancestor_of_repeat(input_dataset):
@@ -1858,6 +1853,10 @@
         self.pad = bool(pad_info is not None)
         self.pad_info = replace_none(pad_info, dict())

+        self.python_multiprocessing = python_multiprocessing
+        self.process_pool = None
+        self.hook = None
+
     def parse(self, children=None):
         return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad,
                              self.input_columns, self.output_columns,
@@ -1923,9 +1922,32 @@
         new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
         new_op.column_order = copy.deepcopy(self.column_order, memodict)
         new_op.pad = self.pad
+        new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
+        new_op.hook = copy.deepcopy(self.hook, memodict)
         new_op.pad_info = copy.deepcopy(self.pad_info, memodict)
         return new_op

+    # Iterator bootstrap will be called on iterator construction.
+    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
+    # This method creates a per-iterator process pool and binds pyfunc execution to that pool.
+    def iterator_bootstrap(self):
+        """
+        Per-iterator bootstrap callback.
+        """
+        if self.python_multiprocessing:
+            # Construct the pool with the callable list.
+            # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses.
+            self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
+                                                     initializer=_pyfunc_worker_init,
+                                                     initargs=([self.per_batch_map],))
+            idx = 0
+            # Wrap per_batch_map into _PythonCallable
+            self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
+
+    def __del__(self):
+        if hasattr(self, 'process_pool') and self.process_pool is not None:
+            self.process_pool.close()
+

 class BatchInfo(cde.CBatchInfo):
     """
@@ -2352,7 +2374,6 @@ class MapDataset(Dataset):
                     # CPP ops remain the same
                     iter_specific_operations.append(op)
             self.operations = iter_specific_operations
-            self.hook = _ExceptHookHandler(self.process_pool)

     def __del__(self):
         if hasattr(self, 'process_pool') and self.process_pool is not None:
@@ -538,7 +538,7 @@ def check_batch(method):
     @wraps(method)
     def new_method(self, *args, **kwargs):
         [batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, output_columns,
-         column_order, pad_info], param_dict = parse_user_args(method, *args, **kwargs)
+         column_order, pad_info, python_multiprocessing], param_dict = parse_user_args(method, *args, **kwargs)

         if not (isinstance(batch_size, int) or (callable(batch_size))):
             raise TypeError("batch_size should either be an int or a callable.")

@@ -577,6 +577,9 @@
         if column_order is not None:
             check_columns(column_order, "column_order")

+        if python_multiprocessing is not None:
+            type_check(python_multiprocessing, (bool,), "python_multiprocessing")
+
         return method(self, *args, **kwargs)

     return new_method
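The added check makes a mistyped flag fail fast inside check_batch rather than surfacing later in the pipeline. A self-contained approximation of the guard; the real type_check lives in the dataset validator helpers, and its exact message format is assumed here:

    def type_check(arg, valid_types, arg_name):
        # Approximation of the project's validator helper.
        if not isinstance(arg, valid_types):
            raise TypeError("Argument {0} with value {1} is not of type {2}."
                            .format(arg_name, arg, list(valid_types)))

    type_check(True, (bool,), "python_multiprocessing")   # passes silently
    type_check("yes", (bool,), "python_multiprocessing")  # raises TypeError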

File diff suppressed because it is too large.