!3935 Decouple ME and AKG for Ascend
Merge pull request !3935 from ZhangQinghua/masterpull/3935/MERGE
commit
0154bdeb70
@ -0,0 +1,88 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""akg process"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
def _compile_akg_task(*json_strs):
|
||||
"""
|
||||
compile func called in single process
|
||||
|
||||
Parameters:
|
||||
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
|
||||
"""
|
||||
akg_compiler = os.path.join(os.path.split(
|
||||
os.path.realpath(__file__))[0], "compiler.py")
|
||||
for json_str in json_strs:
|
||||
res = subprocess.run(
|
||||
[sys.executable, akg_compiler, json_str], text=True)
|
||||
if res.returncode != 0:
|
||||
raise ValueError("Failed, args: {}!".format(json_str))
|
||||
|
||||
def create_akg_parallel_process(process_num, wait_time):
|
||||
"""
|
||||
create AkgParallelCompiler object
|
||||
|
||||
Returns:
|
||||
AkgParallelCompiler
|
||||
"""
|
||||
return AkgProcess(process_num, wait_time)
|
||||
|
||||
class AkgProcess:
|
||||
"""akg kernel parallel process"""
|
||||
|
||||
def __init__(self, process_num, wait_time):
|
||||
"""
|
||||
Args:
|
||||
process_num: int. processes number
|
||||
waittime: int. max time the function blocked
|
||||
"""
|
||||
if not isinstance(process_num, int):
|
||||
raise ValueError("process number must be a num")
|
||||
if not isinstance(wait_time, int):
|
||||
raise ValueError("wait time must be a num")
|
||||
if process_num == 0:
|
||||
process_num = 1
|
||||
max_proc_num = 16
|
||||
self.process_num = min([cpu_count(), max_proc_num, process_num])
|
||||
self.args = [[] for _ in range(self.process_num)]
|
||||
self.wait_time = wait_time
|
||||
self.argc = 0
|
||||
|
||||
def compile(self):
|
||||
"""
|
||||
compile kernel by multi processes
|
||||
Return:
|
||||
True for all compile success, False for some failed.
|
||||
"""
|
||||
if self.argc == 0:
|
||||
raise ValueError("json must be not null")
|
||||
with Pool(processes=self.process_num) as pool:
|
||||
res = pool.starmap_async(_compile_akg_task, self.args)
|
||||
res.get(timeout=self.wait_time)
|
||||
return True
|
||||
|
||||
def accept_json(self, json):
|
||||
"""
|
||||
accept json data before compile
|
||||
Args:
|
||||
json: str. kernel info.
|
||||
"""
|
||||
if not isinstance(json, str):
|
||||
raise ValueError("json must be a str")
|
||||
self.args[self.argc % self.process_num].append(json)
|
||||
self.argc += 1
|
@ -1,71 +0,0 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing multi process compile with json"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
|
||||
def _compile_akg_task(*json_strs):
|
||||
"""
|
||||
compile func called in single process
|
||||
|
||||
Parameters:
|
||||
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
|
||||
"""
|
||||
akg_compiler = os.path.join(os.path.split(
|
||||
os.path.realpath(__file__))[0], "compiler.py")
|
||||
for json_str in json_strs:
|
||||
res = subprocess.run(
|
||||
[sys.executable, akg_compiler, json_str], text=True)
|
||||
if res.returncode != 0:
|
||||
raise ValueError("Failed, args: {}!".format(json_str))
|
||||
|
||||
|
||||
def compile_akg_kernel_parallel(json_infos, process, waitime):
|
||||
"""
|
||||
compile kernel use multi processes
|
||||
|
||||
Parameters:
|
||||
json_infos: list. list contain kernel info(task id and json str)
|
||||
process: int. processes num
|
||||
waittime: int. max time the function blocked
|
||||
|
||||
Returns:
|
||||
True for all compile success, False for some failed.
|
||||
"""
|
||||
if not isinstance(json_infos, list):
|
||||
raise ValueError("json_infos must be a list")
|
||||
if not isinstance(process, int):
|
||||
raise ValueError("process must be a num")
|
||||
if not isinstance(waitime, int):
|
||||
raise ValueError("waittime must be a num")
|
||||
|
||||
if process == 0 and json_infos:
|
||||
process = 1
|
||||
|
||||
cpu_proc_num = cpu_count()
|
||||
max_proc_num = 16
|
||||
process = min([cpu_proc_num, max_proc_num, process])
|
||||
|
||||
args = [[] for _ in range(process)]
|
||||
for p, info in enumerate(json_infos):
|
||||
args[p % process].append(info)
|
||||
|
||||
with Pool(processes=process) as pool:
|
||||
res = pool.starmap_async(_compile_akg_task, args)
|
||||
res.get(timeout=waitime)
|
||||
return True
|
Loading…
Reference in new issue