Fix the compile issue for cuda device (test=develop)

revert-14786-revert-14782-revert-14398-imperative
Yihua Xu 7 years ago
parent 669191c9cc
commit ac803fed18

@ -112,59 +112,35 @@ class TestConv3dOp(OpTest):
return core.is_compiled_with_cuda() and self.use_cudnn return core.is_compiled_with_cuda() and self.use_cudnn
def test_check_output(self): def test_check_output(self):
if self.testcudnn(): place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) self.check_output_with_place(place, atol=1e-5)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
if self.testcudnn(): place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) self.check_grad_with_place(
self.check_grad_with_place( place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03)
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.03)
else:
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
if self.testcudnn(): place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) self.check_grad_with_place(
self.check_grad_with_place( place, ['Input'],
place, ['Input'], 'Output',
'Output', max_relative_error=0.03,
max_relative_error=0.03, no_grad_set=set(['Filter']))
no_grad_set=set(['Filter']))
else:
self.check_grad(
['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
if self.testcudnn(): place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
place = core.CUDAPlace(0) self.check_grad_with_place(
self.check_grad_with_place( place, ['Input'],
place, ['Filter'], 'Output',
'Output', max_relative_error=0.03,
max_relative_error=0.03, no_grad_set=set(['Input']))
no_grad_set=set(['Input']))
else:
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0, 0] self.pad = [0, 0, 0]

Loading…
Cancel
Save