From cbca482e599a3dfaabdefdb560af859019597393 Mon Sep 17 00:00:00 2001
From: yangzhenzhang
Date: Wed, 20 Jan 2021 14:51:22 +0800
Subject: [PATCH] delete useless parameter in pipeline parallel

---
 .../parallel/graph_util/get_parallel_info.cc  | 16 ++++++++
 .../parallel/graph_util/get_parallel_info.h   |  1 +
 mindspore/ccsrc/pipeline/jit/init.cc          |  2 +
 mindspore/ccsrc/pipeline/jit/pipeline.cc      |  6 +++
 mindspore/ccsrc/pipeline/jit/pipeline.h       |  1 +
 mindspore/common/api.py                       |  5 ++-
 mindspore/nn/cell.py                          | 37 ++++++++++++++++++-
 mindspore/parallel/_utils.py                  |  4 ++
 .../ut/python/parallel/test_pipeline_split.py |  7 +++-
 9 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
index fcac6abaf7..6c088485e4 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
@@ -79,5 +79,21 @@ py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
   }
   return dict;
 }
+
+// In pipeline parallel mode, many parameters are not used and need to be deleted
+py::list GetParallelParameterNameList(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+
+  py::list parallel_parameter_name_list;
+  std::vector<AnfNodePtr> graph_params = graph->parameters();
+
+  for (auto param : graph_params) {
+    auto param_ptr = std::static_pointer_cast<Parameter>(param);
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    std::string name = param_ptr->name();
+    parallel_parameter_name_list.append(name);
+  }
+  return parallel_parameter_name_list;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h
index ac86571e9b..f0dc6bcc44 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h
@@ -26,6 +26,7 @@ namespace mindspore {
 namespace parallel {
 py::dict GetParameterLayout(const FuncGraphPtr &graph);
 py::dict GetAllreduceFusion(const FuncGraphPtr &graph);
+py::list GetParallelParameterNameList(const FuncGraphPtr &graph);
 }  // namespace parallel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index 5d56c5b8fd..ac0c8d01bb 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -80,6 +80,8 @@ PYBIND11_MODULE(_c_expression, m) {
          py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
     .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
          "Get Parameter Tensor Layout Dictionary.")
+    .def("get_parallel_parameter_name_list", &ExecutorPy::GetParallelParameterNameList,
+         py::arg("phase") = py::str("train"), "Get Parallel Parameter Name List.")
     .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
          "Get CNode Strategy Dictionary.")
     .def("get_num_parallel_ops", &ExecutorPy::GetNumOpsInfo, py::arg("phase") = py::str("train"),
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 56e58c2f9d..68d58bcfa4 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -286,6 +286,12 @@ py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) {
   return stra_dict_[phase];
 }
 
+py::list ExecutorPy::GetParallelParameterNameList(const std::string &phase) {
+  std::string param_graph = phase + kStepParallelGraph;
+  auto graph = GetFuncGraph(param_graph);
+  return mindspore::parallel::GetParallelParameterNameList(graph);
+}
+
 void ExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
   MS_LOG(DEBUG) << "SetCNodeStrategy!";
   stra_dict_[phase_][py::str(name)] = strategy;
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h
index 5d40b5dc2a..2d0caba826 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.h
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.h
@@ -93,6 +93,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
   void PyExePath(const py::object &phase);
   py::dict GetParameterLayout(const std::string &phase);
   py::dict GetCNodeStrategy(const std::string &phase);
+  py::list GetParallelParameterNameList(const std::string &phase);
   void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy);
   size_t GetNumOpsInfo(const std::string &phase);
   void SetNumOpsInfo(size_t);
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 9ae087aa70..5dfe8fe092 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -27,7 +27,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
 from ..parallel._ps_context import _is_role_pserver
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
-    _get_parameter_broadcast
+    _get_parameter_broadcast, _get_pipeline_stages
 
 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
@@ -501,6 +501,9 @@ class _Executor:
 
         if auto_parallel_mode:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
+            if _get_pipeline_stages() > 1:
+                obj.parallel_parameter_name_list = self._executor.get_parallel_parameter_name_list(phase)
+                obj.remove_redundant_parameters()
         replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
         if not enable_debug_runtime or enable_ge:
             if auto_parallel_mode:
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 442457b068..9191eed8e5 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -89,6 +89,7 @@ class Cell(Cell_):
         self._scope = None
         self._phase = 'train'
         self._parameter_layout_dict = {}
+        self._parallel_parameter_name_list = ()
         self._create_time = int(time.time() * 1e9)
         self.phase_prefix = ""
         self.parameter_broadcast_done = False
@@ -213,6 +214,16 @@ class Cell(Cell_):
             raise TypeError("'parameter_layout_dict' must be dict type.")
         self._parameter_layout_dict = value
 
+    @property
+    def parallel_parameter_name_list(self):
+        return self._parallel_parameter_name_list
+
+    @parallel_parameter_name_list.setter
+    def parallel_parameter_name_list(self, value):
+        if not isinstance(value, list):
+            raise TypeError("'parallel_parameter_name_list' must be list type.")
+        self._parallel_parameter_name_list = value
+
     def get_func_graph_proto(self):
         """Return graph binary proto."""
         return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)
@@ -656,6 +667,28 @@ class Cell(Cell_):
         """
         return None
 
+    def remove_redundant_parameters(self):
+        """Remove the redundant parameters"""
+        cells = self.cells_and_names()
+        for _, cell in cells:
+            params = cell._params.items()
+            for param_name, param in list(params):
+                if param.name not in self.parallel_parameter_name_list:
+                    cell._params.pop(param_name)
+                    logger.info("remove the redundant parameter: %s", param.name)
+                    continue
+            cell_dict = cell.__dict__
+            for key in cell_dict:
+                if isinstance(cell_dict[key], ParameterTuple):
+                    param_tuple = cell_dict[key]
+                    new_param_tuple = []
+                    for param in param_tuple:
+                        if param.name not in self.parallel_parameter_name_list:
+                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
+                            continue
+                        new_param_tuple.append(param)
+                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
+
     def init_parameters_data(self, auto_parallel_mode=False):
         """
         Initialize all parameters and replace the original saved parameters in cell.
@@ -750,7 +783,7 @@ class Cell(Cell_):
         """
         Returns all trainable parameters.
 
-        Returns a list of all trainable parmeters.
+        Returns a list of all trainable parameters.
 
         Args:
             recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
@@ -1031,7 +1064,7 @@ class Cell(Cell_):
         Note:
             fn must be defined as the following code. `cell_name` is the name of registered cell.
            `grad_input` is gradient passed to the cell. `grad_output` is the gradient computed and passed to the
-            next cell or primitve, which may be modified and returned.
+            next cell or primitive, which may be modified and returned.
             hook_fn(cell_name, grad_input, grad_output) -> Tensor or None.
 
         Args:
diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py
index e602fa6a3a..8dc97fc8f2 100644
--- a/mindspore/parallel/_utils.py
+++ b/mindspore/parallel/_utils.py
@@ -35,6 +35,10 @@ def _get_full_batch():
     """Get whether to use full_batch."""
     return auto_parallel_context().get_full_batch()
 
+def _get_pipeline_stages():
+    """Get pipeline stages"""
+    return auto_parallel_context().get_pipeline_stages()
+
 def _check_full_batch():
     """
     full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py
index a77fe2f968..abc09fb44e 100644
--- a/tests/ut/python/parallel/test_pipeline_split.py
+++ b/tests/ut/python/parallel/test_pipeline_split.py
@@ -120,7 +120,9 @@ def test_pipeline_split_stage0():
     optimizer = nn.Lamb(params, learning_rate=0.01)
     model = Model(net, optimizer=optimizer)
     model.train(2, dataset, dataset_sink_mode=False)
-
+    for _, param in model._train_network.parameters_and_names():
+        assert param.name != "cell.block.1.param"
+        assert param.name != "cell.block.1.param1"
 
 def test_pipeline_split_stage1():
     context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
@@ -135,6 +137,9 @@ def test_pipeline_split_stage1():
     optimizer = nn.Lamb(params, learning_rate=0.01)
     model = Model(net, optimizer=optimizer)
     model.train(2, dataset, dataset_sink_mode=False)
+    for _, param in model._train_network.parameters_and_names():
+        assert param.name != "cell.block.0.param"
+        assert param.name != "cell.block.0.param1"
 
 
 def test_pipeline_split_shared_parameter_stage0():
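
Reviewer note (illustrative only, not part of the patch): a minimal sketch of the name-based pruning that Cell.remove_redundant_parameters performs when pipeline_stages > 1. Plain dictionaries stand in for cell._params and ParameterTuple attributes; kept_names, local_params and param_tuples are hypothetical names, not MindSpore APIs.

# Names kept for this stage, e.g. what get_parallel_parameter_name_list(phase) would report.
kept_names = {"cell.block.0.param", "cell.block.0.param1"}

# A cell's local parameter keys mapped to full parameter names.
local_params = {
    "param": "cell.block.0.param",
    "param1": "cell.block.1.param1",
}

# Drop entries whose full name is not in the kept list (mirrors cell._params.pop(...)).
local_params = {k: v for k, v in local_params.items() if v in kept_names}

# ParameterTuple attributes are rebuilt the same way from the surviving members.
param_tuples = {"trainable_params": ["cell.block.0.param", "cell.block.1.param1"]}
param_tuples = {k: tuple(n for n in v if n in kept_names) for k, v in param_tuples.items()}

print(local_params)   # {'param': 'cell.block.0.param'}
print(param_tuples)   # {'trainable_params': ('cell.block.0.param',)}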