[CustomOp]Avoid raising warning while import paddle (#31804)

develop
Aurelius84 4 years ago committed by GitHub
parent 84a551380e
commit f2cfc0f46d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -400,14 +400,14 @@ class BuildExtension(build_ext, object):
# ncvv compile CUDA source
if is_cuda_file(src):
if core.is_compiled_with_rocm():
assert ROCM_HOME is not None
assert ROCM_HOME is not None, "Not found ROCM runtime, please use `export ROCM_PATH= XXX` to specific it."
hipcc_cmd = os.path.join(ROCM_HOME, 'bin', 'hipcc')
self.compiler.set_executable('compiler_so', hipcc_cmd)
# {'nvcc': {}, 'cxx: {}}
if isinstance(cflags, dict):
cflags = cflags['hipcc']
else:
assert CUDA_HOME is not None
assert CUDA_HOME is not None, "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it."
nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
self.compiler.set_executable('compiler_so', nvcc_cmd)
# {'nvcc': {}, 'cxx: {}}
@ -470,7 +470,7 @@ class BuildExtension(build_ext, object):
src = src_list[0]
obj = obj_list[0]
if is_cuda_file(src):
assert CUDA_HOME is not None
assert CUDA_HOME is not None, "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it."
nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
if isinstance(self.cflags, dict):
cflags = self.cflags['nvcc']

@ -461,9 +461,6 @@ def find_cuda_home():
if cuda_home and not os.path.exists(
cuda_home) and core.is_compiled_with_cuda():
cuda_home = None
warnings.warn(
"Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it."
)
return cuda_home
@ -494,9 +491,6 @@ def find_rocm_home():
if rocm_home and not os.path.exists(
rocm_home) and core.is_compiled_with_rocm():
rocm_home = None
warnings.warn(
"Not found ROCM runtime, please use `export ROCM_PATH= XXX` to specific it."
)
return rocm_home

Loading…
Cancel
Save