|
|
|
@ -18,16 +18,21 @@
|
|
|
|
|
import types
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from functools import wraps
|
|
|
|
|
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from .tensor import Tensor as MsTensor
|
|
|
|
|
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
|
|
|
|
|
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
|
|
|
|
|
from .tensor import Tensor as MsTensor
|
|
|
|
|
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
|
|
|
|
|
from ..parallel._ps_context import _is_role_pserver
|
|
|
|
|
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
|
|
|
|
|
_get_parameter_broadcast
|
|
|
|
|
|
|
|
|
|
# store ms_function class compiled pipeline cache
|
|
|
|
|
ms_compile_cache = {}
|
|
|
|
|
|
|
|
|
|
BROADCAST_PHASE = "_broadcast_"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_function_arguments(fn, *args):
|
|
|
|
|
"""
|
|
|
|
@ -362,6 +367,27 @@ class _Executor:
|
|
|
|
|
def _build_data_graph(self, obj, phase):
|
|
|
|
|
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
|
|
|
|
|
|
|
|
|
|
def _get_auto_split_param_names(self, parameter_layout_dict):
|
|
|
|
|
auto_split_params = {}
|
|
|
|
|
for key, value in parameter_layout_dict.items():
|
|
|
|
|
for dim in value[1]:
|
|
|
|
|
if dim != -1:
|
|
|
|
|
auto_split_params[key] = value
|
|
|
|
|
break
|
|
|
|
|
auto_split_param_names = (param_name for param_name in auto_split_params)
|
|
|
|
|
return auto_split_param_names
|
|
|
|
|
|
|
|
|
|
def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase):
|
|
|
|
|
"""Build broadcast graph."""
|
|
|
|
|
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
|
|
|
|
|
|
|
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params)
|
|
|
|
|
_broadcast_net.phase = broadcast_phase
|
|
|
|
|
broadcasted_params = _broadcast_net()
|
|
|
|
|
parameters_broadcast_dict = obj.parameters_broadcast_dict()
|
|
|
|
|
for param_name, param in zip(parameters_broadcast_dict, broadcasted_params):
|
|
|
|
|
parameters_broadcast_dict[param_name].set_data(param)
|
|
|
|
|
|
|
|
|
|
def _set_dataset_mode(self, args_list):
|
|
|
|
|
"""set dataset mode."""
|
|
|
|
|
# decide whether to sink based on whether the inputs is virtual or args_list is ()
|
|
|
|
@ -444,6 +470,15 @@ class _Executor:
|
|
|
|
|
_exec_init_graph(obj, init_phase)
|
|
|
|
|
elif not enable_ge and "export" in phase:
|
|
|
|
|
self._build_data_graph(obj, phase)
|
|
|
|
|
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
|
|
|
|
|
auto_split_param_names = []
|
|
|
|
|
if auto_parallel_mode:
|
|
|
|
|
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
|
|
|
|
|
broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if
|
|
|
|
|
param_name not in auto_split_param_names]
|
|
|
|
|
broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time)
|
|
|
|
|
self._build_broadcast_graph(obj, broadcast_params, broadcast_phase)
|
|
|
|
|
self.compile_cache[phase] = broadcast_phase
|
|
|
|
|
|
|
|
|
|
return phase, True
|
|
|
|
|
|
|
|
|
|