|
|
|
@ -255,7 +255,6 @@ def ms_function(fn=None, obj=None, input_signature=None):
|
|
|
|
|
process_obj = obj
|
|
|
|
|
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
|
|
|
|
|
process_obj = args[0]
|
|
|
|
|
args = (x.default_input if hasattr(x, 'default_input') else x for x in args)
|
|
|
|
|
return _MindSporeFunction(func, input_signature, process_obj)(*args)
|
|
|
|
|
|
|
|
|
|
return staging_specialize
|
|
|
|
@ -354,28 +353,8 @@ class _Executor:
|
|
|
|
|
raise RuntimeError("Failure to init and dataset subgraph!")
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def _build_data_graph(self, obj, params, phase):
|
|
|
|
|
if params is None:
|
|
|
|
|
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
|
|
|
|
|
elif isinstance(params, OrderedDict):
|
|
|
|
|
self._executor.build_data_graph(params, phase)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError('Parameters need OrderedDict type, but got {}'.
|
|
|
|
|
format(type(params)))
|
|
|
|
|
|
|
|
|
|
def _params_init_data(self, obj, params, auto_parallel_mode=False):
|
|
|
|
|
"""Init parameters' data."""
|
|
|
|
|
if params is not None:
|
|
|
|
|
for key, param in params.items():
|
|
|
|
|
if not auto_parallel_mode:
|
|
|
|
|
param.init_data()
|
|
|
|
|
elif key not in obj.parameter_layout_dict:
|
|
|
|
|
logger.debug("Layout dict does not contain the key %s.", key)
|
|
|
|
|
param.init_data(set_sliced=True)
|
|
|
|
|
else:
|
|
|
|
|
layout = obj.parameter_layout_dict[key]
|
|
|
|
|
param.init_data(layout, set_sliced=True)
|
|
|
|
|
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
|
|
|
|
|
def _build_data_graph(self, obj, phase):
|
|
|
|
|
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
|
|
|
|
|
|
|
|
|
|
def _set_dataset_mode(self, args_list):
|
|
|
|
|
"""set dataset mode."""
|
|
|
|
@ -386,7 +365,7 @@ class _Executor:
|
|
|
|
|
else:
|
|
|
|
|
_set_dataset_mode_config('normal')
|
|
|
|
|
|
|
|
|
|
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
|
|
|
|
|
def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
|
|
|
|
|
"""
|
|
|
|
|
Compiles graph.
|
|
|
|
|
|
|
|
|
@ -394,7 +373,6 @@ class _Executor:
|
|
|
|
|
obj (Function/Cell): The function or cell instance need compile.
|
|
|
|
|
args (tuple): Function or cell input arguments.
|
|
|
|
|
phase (str): The name of compile phase. Default: 'predict'.
|
|
|
|
|
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
|
|
|
|
|
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
|
|
|
|
|
auto_parallel_mode: When set to True, use auto parallel mode to compile graph.
|
|
|
|
|
|
|
|
|
@ -435,10 +413,8 @@ class _Executor:
|
|
|
|
|
|
|
|
|
|
if auto_parallel_mode:
|
|
|
|
|
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
|
|
|
|
|
self._params_init_data(obj, params, auto_parallel_mode)
|
|
|
|
|
if not enable_debug_runtime or enable_ge:
|
|
|
|
|
if auto_parallel_mode:
|
|
|
|
|
obj.load_parameter_slice(params)
|
|
|
|
|
replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
|
|
|
|
|
self._updata_param_node_default_input(phase, replace)
|
|
|
|
|
|
|
|
|
|
# set parallel inputs in sink mode
|
|
|
|
|
if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
|
|
|
|
@ -446,16 +422,20 @@ class _Executor:
|
|
|
|
|
|
|
|
|
|
# the following GE init process is not needed when use vm or ms backend
|
|
|
|
|
if enable_ge:
|
|
|
|
|
self._build_data_graph(obj, params, phase)
|
|
|
|
|
self._build_data_graph(obj, phase)
|
|
|
|
|
|
|
|
|
|
if "export" not in phase:
|
|
|
|
|
init_phase = "init_subgraph" + "." + str(obj.create_time)
|
|
|
|
|
_exec_init_graph(obj, init_phase)
|
|
|
|
|
elif not enable_ge and "export" in phase:
|
|
|
|
|
self._build_data_graph(obj, params, phase)
|
|
|
|
|
self._build_data_graph(obj, phase)
|
|
|
|
|
|
|
|
|
|
return phase, True
|
|
|
|
|
|
|
|
|
|
def _updata_param_node_default_input(self, phase, replace):
|
|
|
|
|
new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
|
|
|
|
|
return self._executor.updata_param_node_default_input(phase, new_param)
|
|
|
|
|
|
|
|
|
|
def _get_strategy(self, obj):
|
|
|
|
|
real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
|
|
|
|
|
return self._executor.get_strategy(real_phase)
|
|
|
|
|