|
|
|
@ -400,14 +400,14 @@ class BuildExtension(build_ext, object):
|
|
|
|
|
# ncvv compile CUDA source
|
|
|
|
|
if is_cuda_file(src):
|
|
|
|
|
if core.is_compiled_with_rocm():
|
|
|
|
|
assert ROCM_HOME is not None
|
|
|
|
|
assert ROCM_HOME is not None, "Not found ROCM runtime, please use `export ROCM_PATH= XXX` to specific it."
|
|
|
|
|
hipcc_cmd = os.path.join(ROCM_HOME, 'bin', 'hipcc')
|
|
|
|
|
self.compiler.set_executable('compiler_so', hipcc_cmd)
|
|
|
|
|
# {'nvcc': {}, 'cxx: {}}
|
|
|
|
|
if isinstance(cflags, dict):
|
|
|
|
|
cflags = cflags['hipcc']
|
|
|
|
|
else:
|
|
|
|
|
assert CUDA_HOME is not None
|
|
|
|
|
assert CUDA_HOME is not None, "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it."
|
|
|
|
|
nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
|
|
|
|
|
self.compiler.set_executable('compiler_so', nvcc_cmd)
|
|
|
|
|
# {'nvcc': {}, 'cxx: {}}
|
|
|
|
@ -470,7 +470,7 @@ class BuildExtension(build_ext, object):
|
|
|
|
|
src = src_list[0]
|
|
|
|
|
obj = obj_list[0]
|
|
|
|
|
if is_cuda_file(src):
|
|
|
|
|
assert CUDA_HOME is not None
|
|
|
|
|
assert CUDA_HOME is not None, "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it."
|
|
|
|
|
nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
|
|
|
|
|
if isinstance(self.cflags, dict):
|
|
|
|
|
cflags = self.cflags['nvcc']
|
|
|
|
|