From 4f8de10897129864e142b38f121a346cd4c27174 Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Wed, 25 Nov 2020 11:33:36 +0800 Subject: [PATCH] fix bug --- mindspore/common/api.py | 22 +++++++++++++--------- mindspore/train/model.py | 1 - 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 265361c3de..02e9876ede 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -379,16 +379,15 @@ class _Executor: auto_split_param_names = (param_name for param_name in auto_split_params) return auto_split_param_names - def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase): + def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase): """Build broadcast graph.""" from mindspore.nn.wrap.cell_wrapper import _BroadCastCell - _broadcast_net = _BroadCastCell(broadcast_params) + _broadcast_net = _BroadCastCell(broadcast_params_dict.values()) _broadcast_net.phase = broadcast_phase broadcasted_params = _broadcast_net() - parameters_broadcast_dict = obj.parameters_broadcast_dict() - for param_name, param in zip(parameters_broadcast_dict, broadcasted_params): - parameters_broadcast_dict[param_name].set_data(param) + for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params): + broadcast_params_dict[param_name].set_data(param) def _set_dataset_mode(self, args_list): """set dataset mode.""" @@ -476,10 +475,15 @@ class _Executor: auto_split_param_names = [] if auto_parallel_mode: auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict) - broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if - param_name not in auto_split_param_names] - broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time) - self._build_broadcast_graph(obj, broadcast_params, broadcast_phase) + + broadcast_params_dict = obj.parameters_broadcast_dict() + if auto_split_param_names and broadcast_params_dict: + broadcast_params_dict = OrderedDict() + for param_name, param in obj.parameters_broadcast_dict().items(): + if param_name not in auto_split_param_names: + broadcast_params_dict[param_name] = param + broadcast_phase = "_broadcast_subgraph" + "." + str(obj.create_time) + self._build_broadcast_graph(broadcast_params_dict, broadcast_phase) self.compile_cache[phase] = broadcast_phase return phase, True diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 5cd2ab522f..95271406f1 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -563,7 +563,6 @@ class Model: raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) _device_number_check(self._parallel_mode, self._device_number) - _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) self._train(epoch, train_dataset,