!11499 delete useless params in pipeline parallel

From: @yangzhenzhang
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
pull/11499/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit 6394ba3974

@@ -79,5 +79,21 @@ py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
  }
  return dict;
}

// In pipeline parallel mode, many parameters are not used and need to be deleted
py::list GetParallelParameterNameList(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  py::list parallel_parameter_name_list;
  std::vector<AnfNodePtr> graph_params = graph->parameters();
  for (auto param : graph_params) {
    auto param_ptr = std::static_pointer_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_ptr);
    std::string name = param_ptr->name();
    parallel_parameter_name_list.append(name);
  }
  return parallel_parameter_name_list;
}
} // namespace parallel
} // namespace mindspore
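Note: the helper above just walks the parameters of the pipeline-split graph and
collects their names. A rough Python equivalent, for illustration only (parameter
objects are assumed to expose a .name attribute, mirroring the C++ code):

    def get_parallel_parameter_name_list(graph_params):
        """Collect the names of the parameters kept after pipeline splitting."""
        return [param.name for param in graph_params]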

@@ -26,6 +26,7 @@ namespace mindspore {
namespace parallel {
py::dict GetParameterLayout(const FuncGraphPtr &graph);
py::dict GetAllreduceFusion(const FuncGraphPtr &graph);
py::list GetParallelParameterNameList(const FuncGraphPtr &graph);
} // namespace parallel
} // namespace mindspore

@@ -80,6 +80,8 @@ PYBIND11_MODULE(_c_expression, m) {
         py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
    .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
         "Get Parameter Tensor Layout Dictionary.")
    .def("get_parallel_parameter_name_list", &ExecutorPy::GetParallelParameterNameList,
         py::arg("phase") = py::str("train"), "Get Parallel Parameter Name List.")
    .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
         "Get CNode Strategy Dictionary.")
    .def("get_num_parallel_ops", &ExecutorPy::GetNumOpsInfo, py::arg("phase") = py::str("train"),

@@ -286,6 +286,12 @@ py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) {
  return stra_dict_[phase];
}

py::list ExecutorPy::GetParallelParameterNameList(const std::string &phase) {
  std::string param_graph = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(param_graph);
  return mindspore::parallel::GetParallelParameterNameList(graph);
}

void ExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) {
  MS_LOG(DEBUG) << "SetCNodeStrategy!";
  stra_dict_[phase_][py::str(name)] = strategy;
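The executor resolves the cached step-parallel graph for a phase by appending the
kStepParallelGraph suffix to the phase name, then delegates to the parallel helper.
A framework-free sketch of that lookup pattern (all names here are placeholders,
not MindSpore APIs):

    # Illustrative only: a phase-keyed table of compiled graphs.
    compiled_graphs = {}

    def parallel_parameter_names(phase, suffix):
        graph = compiled_graphs.get(phase + suffix)  # e.g. phase + kStepParallelGraph
        return [p.name for p in graph.parameters()] if graph is not None else []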

@@ -93,6 +93,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
  void PyExePath(const py::object &phase);
  py::dict GetParameterLayout(const std::string &phase);
  py::dict GetCNodeStrategy(const std::string &phase);
  py::list GetParallelParameterNameList(const std::string &phase);
  void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy);
  size_t GetNumOpsInfo(const std::string &phase);
  void SetNumOpsInfo(size_t);

@@ -27,7 +27,7 @@ from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, Pynativ
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
-    _get_parameter_broadcast
+    _get_parameter_broadcast, _get_pipeline_stages

# store ms_function class compiled pipeline cache
ms_compile_cache = {}
@@ -501,6 +501,9 @@ class _Executor:
        if auto_parallel_mode:
            obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
            if _get_pipeline_stages() > 1:
                obj.parallel_parameter_name_list = self._executor.get_parallel_parameter_name_list(phase)
                obj.remove_redundant_parameters()
        replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
        if not enable_debug_runtime or enable_ge:
            if auto_parallel_mode:
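The pruning only runs when pipeline parallelism is actually configured. A hedged
sketch of the configuration under which the new branch is taken (pipeline_stages
greater than 1); the network definition and the Model/compile call are omitted:

    from mindspore import context

    context.set_auto_parallel_context(device_num=8, pipeline_stages=2,
                                      parallel_mode="semi_auto_parallel")
    # During compilation, _Executor now fetches the per-stage parameter name list
    # and calls remove_redundant_parameters() on the compiled cell.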

@@ -89,6 +89,7 @@ class Cell(Cell_):
        self._scope = None
        self._phase = 'train'
        self._parameter_layout_dict = {}
        self._parallel_parameter_name_list = ()
        self._create_time = int(time.time() * 1e9)
        self.phase_prefix = ""
        self.parameter_broadcast_done = False
@@ -213,6 +214,16 @@
            raise TypeError("'parameter_layout_dict' must be dict type.")
        self._parameter_layout_dict = value

    @property
    def parallel_parameter_name_list(self):
        return self._parallel_parameter_name_list

    @parallel_parameter_name_list.setter
    def parallel_parameter_name_list(self, value):
        if not isinstance(value, list):
            raise TypeError("'parallel_parameter_name_list' must be list type.")
        self._parallel_parameter_name_list = value

    def get_func_graph_proto(self):
        """Return graph binary proto."""
        return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)
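The setter only accepts a Python list, matching the py::list returned by the new
binding. A minimal usage sketch; TinyNet and its parameter names are made up for
illustration:

    import mindspore.nn as nn

    class TinyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.fc = nn.Dense(4, 4)

        def construct(self, x):
            return self.fc(x)

    net = TinyNet()
    net.parallel_parameter_name_list = ["fc.weight", "fc.bias"]  # a list is accepted
    # Assigning a tuple or any other type raises TypeError, per the setter above.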
@@ -656,6 +667,28 @@
        """
        return None

    def remove_redundant_parameters(self):
        """Remove the redundant parameters"""
        cells = self.cells_and_names()
        for _, cell in cells:
            params = cell._params.items()
            for param_name, param in list(params):
                if param.name not in self.parallel_parameter_name_list:
                    cell._params.pop(param_name)
                    logger.info("remove the redundant parameter: %s", param.name)
                    continue
            cell_dict = cell.__dict__
            for key in cell_dict:
                if isinstance(cell_dict[key], ParameterTuple):
                    param_tuple = cell_dict[key]
                    new_param_tuple = []
                    for param in param_tuple:
                        if param.name not in self.parallel_parameter_name_list:
                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
                            continue
                        new_param_tuple.append(param)
                    cell.__dict__[key] = ParameterTuple(new_param_tuple)

    def init_parameters_data(self, auto_parallel_mode=False):
        """
        Initialize all parameters and replace the original saved parameters in cell.
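At its core, remove_redundant_parameters applies a name-based filter to each cell's
parameter dict and to any ParameterTuple attributes. A framework-free sketch of that
filter (parameter objects are only assumed to carry a .name attribute):

    def filter_params_by_name(params, keep_names):
        """Split a name->parameter dict into kept and dropped entries."""
        kept = {key: p for key, p in params.items() if p.name in keep_names}
        dropped = [p.name for p in params.values() if p.name not in keep_names]
        return kept, dropped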
@@ -750,7 +783,7 @@
        """
        Returns all trainable parameters.

-        Returns a list of all trainable parmeters.
+        Returns a list of all trainable parameters.

        Args:
            recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
@@ -1031,7 +1064,7 @@
        Note:
            fn must be defined as the following code. `cell_name` is the name of registered cell.
            `grad_input` is gradient passed to the cell. `grad_output` is the gradient computed and passed to the
-            next cell or primitve, which may be modified and returned.
+            next cell or primitive, which may be modified and returned.
            hook_fn(cell_name, grad_input, grad_output) -> Tensor or None.

        Args:

@@ -35,6 +35,10 @@ def _get_full_batch():
    """Get whether to use full_batch."""
    return auto_parallel_context().get_full_batch()

def _get_pipeline_stages():
    """Get pipeline stages"""
    return auto_parallel_context().get_pipeline_stages()

def _check_full_batch():
    """
    full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
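The new helper simply proxies the auto-parallel context, so callers such as
_Executor.compile can branch on the stage count without touching the context
module directly. A hedged usage sketch:

    from mindspore import context
    from mindspore.parallel._utils import _get_pipeline_stages

    context.set_auto_parallel_context(pipeline_stages=2)
    assert _get_pipeline_stages() == 2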

@@ -120,7 +120,9 @@ def test_pipeline_split_stage0():
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"


def test_pipeline_split_stage1():
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
@@ -135,6 +137,9 @@
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"


def test_pipeline_split_shared_parameter_stage0():
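The assertions above verify that parameters belonging to the other pipeline stage
no longer appear on the compiled train network. A small helper in the same spirit,
for illustration (the network object is whatever Model builds internally):

    def remaining_parameter_names(train_network):
        """Names of the parameters still attached to the compiled network."""
        return {param.name for _, param in train_network.parameters_and_names()}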
