|
|
|
@ -379,16 +379,15 @@ class _Executor:
|
|
|
|
|
auto_split_param_names = (param_name for param_name in auto_split_params)
|
|
|
|
|
return auto_split_param_names
|
|
|
|
|
|
|
|
|
|
def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase):
|
|
|
|
|
def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase):
|
|
|
|
|
"""Build broadcast graph."""
|
|
|
|
|
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
|
|
|
|
|
|
|
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params)
|
|
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params_dict.values())
|
|
|
|
|
_broadcast_net.phase = broadcast_phase
|
|
|
|
|
broadcasted_params = _broadcast_net()
|
|
|
|
|
parameters_broadcast_dict = obj.parameters_broadcast_dict()
|
|
|
|
|
for param_name, param in zip(parameters_broadcast_dict, broadcasted_params):
|
|
|
|
|
parameters_broadcast_dict[param_name].set_data(param)
|
|
|
|
|
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
|
|
|
|
|
broadcast_params_dict[param_name].set_data(param)
|
|
|
|
|
|
|
|
|
|
def _set_dataset_mode(self, args_list):
|
|
|
|
|
"""set dataset mode."""
|
|
|
|
@ -476,10 +475,15 @@ class _Executor:
|
|
|
|
|
auto_split_param_names = []
|
|
|
|
|
if auto_parallel_mode:
|
|
|
|
|
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
|
|
|
|
|
broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if
|
|
|
|
|
param_name not in auto_split_param_names]
|
|
|
|
|
broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time)
|
|
|
|
|
self._build_broadcast_graph(obj, broadcast_params, broadcast_phase)
|
|
|
|
|
|
|
|
|
|
broadcast_params_dict = obj.parameters_broadcast_dict()
|
|
|
|
|
if auto_split_param_names and broadcast_params_dict:
|
|
|
|
|
broadcast_params_dict = OrderedDict()
|
|
|
|
|
for param_name, param in obj.parameters_broadcast_dict().items():
|
|
|
|
|
if param_name not in auto_split_param_names:
|
|
|
|
|
broadcast_params_dict[param_name] = param
|
|
|
|
|
broadcast_phase = "_broadcast_subgraph" + "." + str(obj.create_time)
|
|
|
|
|
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
|
|
|
|
|
self.compile_cache[phase] = broadcast_phase
|
|
|
|
|
|
|
|
|
|
return phase, True
|
|
|
|
|