|
|
|
@ -22,7 +22,7 @@ from te.platform.cce_conf import te_set_version
|
|
|
|
|
from te.platform.fusion_manager import set_current_op_name
|
|
|
|
|
from te.platform.fusion_util import fusion_op, dump_fusion_json
|
|
|
|
|
from te.platform.parallel_compilation import init_multi_process_env, get_finished_compilation_task, \
|
|
|
|
|
deinit_multi_process_env, dispatch_autotune_task, start_ga_multi_process
|
|
|
|
|
deinit_multi_process_env, dispatch_autotune_task, start_ga_multi_process, import_py_module
|
|
|
|
|
import auto_tune
|
|
|
|
|
from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \
|
|
|
|
|
rl_tune_deinit
|
|
|
|
@ -48,6 +48,8 @@ class TbeTuner:
|
|
|
|
|
if os.environ.get("TUNE_DUMP_PATH") is not None:
|
|
|
|
|
self.offline_dump_path = os.getenv("TUNE_DUMP_PATH", "")
|
|
|
|
|
self._creating_custom_path(tune_mode)
|
|
|
|
|
self.fusion_need_sync = 0
|
|
|
|
|
self.module_list = {}
|
|
|
|
|
|
|
|
|
|
def init_tune_interface(self, json_str, process_num):
|
|
|
|
|
"""
|
|
|
|
@ -222,6 +224,24 @@ class TbeTuner:
|
|
|
|
|
log.info("GA Tune init success.")
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def sync_fusion_env(self):
|
|
|
|
|
"""
|
|
|
|
|
Sync fusion env
|
|
|
|
|
:return: None
|
|
|
|
|
"""
|
|
|
|
|
if self.fusion_need_sync == 0:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
module_using = []
|
|
|
|
|
for key, value in self.module_list.items():
|
|
|
|
|
if value > 0:
|
|
|
|
|
module_using.append(str(key))
|
|
|
|
|
self.module_list[key] = 0
|
|
|
|
|
|
|
|
|
|
module_str = ",".join(module_using)
|
|
|
|
|
import_py_module(module_str)
|
|
|
|
|
self.fusion_need_sync = 0
|
|
|
|
|
|
|
|
|
|
def rl_tune(self, task_id, op_json):
|
|
|
|
|
"""
|
|
|
|
|
RL tune for single op and fusion op
|
|
|
|
@ -231,6 +251,7 @@ class TbeTuner:
|
|
|
|
|
"""
|
|
|
|
|
json_info = json.loads(op_json)
|
|
|
|
|
if "fusion_op" in json_info:
|
|
|
|
|
self.sync_fusion_env()
|
|
|
|
|
ret = self.fusion_rl_tune(task_id, json_info)
|
|
|
|
|
else:
|
|
|
|
|
ret = self.single_rl_tune(task_id, json_info)
|
|
|
|
@ -244,6 +265,7 @@ class TbeTuner:
|
|
|
|
|
"""
|
|
|
|
|
json_info = json.loads(op_json)
|
|
|
|
|
if "fusion_op" in json_info:
|
|
|
|
|
self.sync_fusion_env()
|
|
|
|
|
self.fusion_ga_tune(task_id, json_info)
|
|
|
|
|
else:
|
|
|
|
|
self.single_ga_tune(task_id, json_info)
|
|
|
|
@ -289,6 +311,9 @@ class TbeTuner:
|
|
|
|
|
l1size = 0 # todo need to verify
|
|
|
|
|
ret = dispatch_single_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, op_module_name,
|
|
|
|
|
op_module_name + "@" + op_module_name, op_type, op_type, op_args)
|
|
|
|
|
|
|
|
|
|
self.module_list[op_module_name] = 1
|
|
|
|
|
self.fusion_need_sync += 1
|
|
|
|
|
return ret, job_type
|
|
|
|
|
|
|
|
|
|
def get_op_module_names(self, json_info):
|
|
|
|
|