From: @jinyaohui
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
pull/8993/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1752964365

@ -379,16 +379,15 @@ class _Executor:
auto_split_param_names = (param_name for param_name in auto_split_params)
return auto_split_param_names
def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase):
def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase):
"""Build broadcast graph."""
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
_broadcast_net = _BroadCastCell(broadcast_params)
_broadcast_net = _BroadCastCell(broadcast_params_dict.values())
_broadcast_net.phase = broadcast_phase
broadcasted_params = _broadcast_net()
parameters_broadcast_dict = obj.parameters_broadcast_dict()
for param_name, param in zip(parameters_broadcast_dict, broadcasted_params):
parameters_broadcast_dict[param_name].set_data(param)
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
broadcast_params_dict[param_name].set_data(param)
def _set_dataset_mode(self, args_list):
"""set dataset mode."""
@ -476,10 +475,15 @@ class _Executor:
auto_split_param_names = []
if auto_parallel_mode:
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if
param_name not in auto_split_param_names]
broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time)
self._build_broadcast_graph(obj, broadcast_params, broadcast_phase)
broadcast_params_dict = obj.parameters_broadcast_dict()
if auto_split_param_names and broadcast_params_dict:
broadcast_params_dict = OrderedDict()
for param_name, param in obj.parameters_broadcast_dict().items():
if param_name not in auto_split_param_names:
broadcast_params_dict[param_name] = param
broadcast_phase = "_broadcast_subgraph" + "." + str(obj.create_time)
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
self.compile_cache[phase] = broadcast_phase
return phase, True

@ -564,7 +564,6 @@ class Model:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
_device_number_check(self._parallel_mode, self._device_number)
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
self._train(epoch,
train_dataset,

Loading…
Cancel
Save