|
|
@ -19,10 +19,10 @@ import types
|
|
|
|
import typing
|
|
|
|
import typing
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
import traceback
|
|
|
|
import traceback
|
|
|
|
import akg.tvm
|
|
|
|
import _akg.tvm
|
|
|
|
import akg
|
|
|
|
import _akg
|
|
|
|
from akg import save_gpu_param as gpu_utils
|
|
|
|
from _akg import save_gpu_param as gpu_utils
|
|
|
|
from akg.utils import validation_check as vc_util
|
|
|
|
from _akg.utils import validation_check as vc_util
|
|
|
|
|
|
|
|
|
|
|
|
MS_CUDA_KERNEL_PATH = "/tmp/cuda_meta/"
|
|
|
|
MS_CUDA_KERNEL_PATH = "/tmp/cuda_meta/"
|
|
|
|
|
|
|
|
|
|
|
@ -38,21 +38,21 @@ def op_build(opnames, computes, args, custom_schedule, device, kernel_name, attr
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
schedule_name = 'gpu_schedule_' + opnames[0]
|
|
|
|
schedule_name = 'gpu_schedule_' + opnames[0]
|
|
|
|
schedule_func = getattr(akg.gpu, schedule_name)
|
|
|
|
schedule_func = getattr(_akg.gpu, schedule_name)
|
|
|
|
if not isinstance(schedule_func, (types.FunctionType, typing.Callable)):
|
|
|
|
if not isinstance(schedule_func, (types.FunctionType, typing.Callable)):
|
|
|
|
logging.error("no schedule func found %s", str(schedule_name))
|
|
|
|
logging.error("no schedule func found %s", str(schedule_name))
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
ptx_file = os.path.realpath(MS_CUDA_KERNEL_PATH + kernel_name + ".ptx")
|
|
|
|
ptx_file = os.path.realpath(MS_CUDA_KERNEL_PATH + kernel_name + ".ptx")
|
|
|
|
if os.path.exists(ptx_file):
|
|
|
|
if os.path.exists(ptx_file):
|
|
|
|
os.remove(ptx_file)
|
|
|
|
os.chmod(ptx_file, 0o600)
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
with open(ptx_file, 'at') as file:
|
|
|
|
with open(ptx_file, 'at') as file:
|
|
|
|
fcntl.flock(file.fileno(), fcntl.LOCK_EX)
|
|
|
|
fcntl.flock(file.fileno(), fcntl.LOCK_EX)
|
|
|
|
file.seek(0, 2)
|
|
|
|
file.seek(0, 2)
|
|
|
|
if file.tell() == 0:
|
|
|
|
if file.tell() == 0:
|
|
|
|
s = schedule_func(computes)
|
|
|
|
s = schedule_func(computes)
|
|
|
|
foo = akg.tvm.build(s, args, device, name=kernel_name)
|
|
|
|
foo = _akg.tvm.build(s, args, device, name=kernel_name)
|
|
|
|
ptx_code = foo.imported_modules[0].get_source("ptx")
|
|
|
|
ptx_code = foo.imported_modules[0].get_source("ptx")
|
|
|
|
file.write(ptx_code)
|
|
|
|
file.write(ptx_code)
|
|
|
|
json_file = os.path.realpath(MS_CUDA_KERNEL_PATH + kernel_name + ".json")
|
|
|
|
json_file = os.path.realpath(MS_CUDA_KERNEL_PATH + kernel_name + ".json")
|