|
|
@ -337,7 +337,6 @@ class _Executor:
|
|
|
|
self.is_init = False
|
|
|
|
self.is_init = False
|
|
|
|
self._executor = Executor_.get_instance()
|
|
|
|
self._executor = Executor_.get_instance()
|
|
|
|
self.compile_cache = {}
|
|
|
|
self.compile_cache = {}
|
|
|
|
self.phase_prefix = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
|
|
|
|
def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
|
|
|
|
input_indexs, phase='dataset'):
|
|
|
|
input_indexs, phase='dataset'):
|
|
|
@ -383,7 +382,12 @@ class _Executor:
|
|
|
|
"""Build broadcast graph."""
|
|
|
|
"""Build broadcast graph."""
|
|
|
|
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
|
|
|
|
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
|
|
|
|
|
|
|
|
|
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params_dict.values())
|
|
|
|
if not broadcast_params_dict:
|
|
|
|
|
|
|
|
broadcast_params_dict = {}
|
|
|
|
|
|
|
|
broadcast_params = []
|
|
|
|
|
|
|
|
for param in broadcast_params_dict.values():
|
|
|
|
|
|
|
|
broadcast_params.append(Tensor(param.asnumpy()))
|
|
|
|
|
|
|
|
_broadcast_net = _BroadCastCell(broadcast_params)
|
|
|
|
_broadcast_net.phase = broadcast_phase
|
|
|
|
_broadcast_net.phase = broadcast_phase
|
|
|
|
broadcasted_params = _broadcast_net()
|
|
|
|
broadcasted_params = _broadcast_net()
|
|
|
|
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
|
|
|
|
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
|
|
|
@ -440,11 +444,11 @@ class _Executor:
|
|
|
|
if not hasattr(obj, "inputs_to_attr"):
|
|
|
|
if not hasattr(obj, "inputs_to_attr"):
|
|
|
|
dic = dict(zip(args_names, args_list))
|
|
|
|
dic = dict(zip(args_names, args_list))
|
|
|
|
key = generate_key(phase, dic)
|
|
|
|
key = generate_key(phase, dic)
|
|
|
|
self.phase_prefix = str(key[1])
|
|
|
|
obj.phase_prefix = str(key[1])
|
|
|
|
if 'export' in phase:
|
|
|
|
if 'export' in phase:
|
|
|
|
phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
|
|
|
|
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
phase = self.phase_prefix + phase + '.' + str(obj.create_time)
|
|
|
|
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)
|
|
|
|
|
|
|
|
|
|
|
|
if phase in self.compile_cache.keys():
|
|
|
|
if phase in self.compile_cache.keys():
|
|
|
|
logger.debug("%r graph has existed.", phase)
|
|
|
|
logger.debug("%r graph has existed.", phase)
|
|
|
@ -518,9 +522,8 @@ class _Executor:
|
|
|
|
for param_name, param in obj.parameters_broadcast_dict().items():
|
|
|
|
for param_name, param in obj.parameters_broadcast_dict().items():
|
|
|
|
if param_name not in auto_split_param_names:
|
|
|
|
if param_name not in auto_split_param_names:
|
|
|
|
broadcast_params_dict[param_name] = param
|
|
|
|
broadcast_params_dict[param_name] = param
|
|
|
|
broadcast_phase = "_broadcast_subgraph" + "." + str(obj.create_time)
|
|
|
|
broadcast_phase = "_broadcast_subgraph"
|
|
|
|
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
|
|
|
|
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
|
|
|
|
self.compile_cache[phase] = broadcast_phase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return phase, True
|
|
|
|
return phase, True
|
|
|
|
|
|
|
|
|
|
|
@ -529,15 +532,15 @@ class _Executor:
|
|
|
|
return self._executor.updata_param_node_default_input(phase, new_param)
|
|
|
|
return self._executor.updata_param_node_default_input(phase, new_param)
|
|
|
|
|
|
|
|
|
|
|
|
def _get_shard_strategy(self, obj):
|
|
|
|
def _get_shard_strategy(self, obj):
|
|
|
|
real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
|
|
|
|
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
|
|
|
|
return self._executor.get_strategy(real_phase)
|
|
|
|
return self._executor.get_strategy(real_phase)
|
|
|
|
|
|
|
|
|
|
|
|
def _get_num_parallel_ops(self, obj):
|
|
|
|
def _get_num_parallel_ops(self, obj):
|
|
|
|
real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
|
|
|
|
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
|
|
|
|
return self._executor.get_num_parallel_ops(real_phase)
|
|
|
|
return self._executor.get_num_parallel_ops(real_phase)
|
|
|
|
|
|
|
|
|
|
|
|
def _get_allreduce_fusion(self, obj):
|
|
|
|
def _get_allreduce_fusion(self, obj):
|
|
|
|
real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
|
|
|
|
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
|
|
|
|
return self._executor.get_allreduce_fusion(real_phase)
|
|
|
|
return self._executor.get_allreduce_fusion(real_phase)
|
|
|
|
|
|
|
|
|
|
|
|
def has_compiled(self, phase='predict'):
|
|
|
|
def has_compiled(self, phase='predict'):
|
|
|
@ -581,7 +584,7 @@ class _Executor:
|
|
|
|
if phase == 'save':
|
|
|
|
if phase == 'save':
|
|
|
|
return self._executor((), phase + '.' + str(obj.create_time))
|
|
|
|
return self._executor((), phase + '.' + str(obj.create_time))
|
|
|
|
|
|
|
|
|
|
|
|
phase_real = self.phase_prefix + phase + '.' + str(obj.create_time)
|
|
|
|
phase_real = obj.phase_prefix + phase + '.' + str(obj.create_time)
|
|
|
|
if self.has_compiled(phase_real):
|
|
|
|
if self.has_compiled(phase_real):
|
|
|
|
return self._exec_pip(obj, *args, phase=phase_real)
|
|
|
|
return self._exec_pip(obj, *args, phase=phase_real)
|
|
|
|
raise KeyError('{} graph is not exist.'.format(phase_real))
|
|
|
|
raise KeyError('{} graph is not exist.'.format(phase_real))
|
|
|
@ -589,10 +592,10 @@ class _Executor:
|
|
|
|
def del_net_res(self, net_id):
|
|
|
|
def del_net_res(self, net_id):
|
|
|
|
self._executor.del_net_res(net_id)
|
|
|
|
self._executor.del_net_res(net_id)
|
|
|
|
|
|
|
|
|
|
|
|
def _get_func_graph_proto(self, exec_id, ir_type="onnx_ir", use_prefix=False):
|
|
|
|
def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
|
|
|
|
"""Get graph proto from pipeline."""
|
|
|
|
"""Get graph proto from pipeline."""
|
|
|
|
if use_prefix:
|
|
|
|
if use_prefix:
|
|
|
|
exec_id = self.phase_prefix + exec_id
|
|
|
|
exec_id = obj.phase_prefix + exec_id
|
|
|
|
if self._executor.has_compiled(exec_id) is False:
|
|
|
|
if self._executor.has_compiled(exec_id) is False:
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
return self._executor.get_func_graph_proto(exec_id, ir_type)
|
|
|
|
return self._executor.get_func_graph_proto(exec_id, ir_type)
|
|
|
|