|
|
|
@ -151,19 +151,13 @@ def create_test_cudnn_fp16_class(parent, grad_check=True):
|
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
|
if core.is_float16_supported(place) and grad_check:
|
|
|
|
|
self.check_grad_with_place(
|
|
|
|
|
place, ['Input'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.02,
|
|
|
|
|
no_grad_set=set(['Filter']))
|
|
|
|
|
place, ['Input'], 'Output', no_grad_set=set(['Filter']))
|
|
|
|
|
|
|
|
|
|
def test_check_grad_no_input(self):
|
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
|
if core.is_float16_supported(place) and grad_check:
|
|
|
|
|
self.check_grad_with_place(
|
|
|
|
|
place, ['Filter'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.02,
|
|
|
|
|
no_grad_set=set(['Input']))
|
|
|
|
|
place, ['Filter'], 'Output', no_grad_set=set(['Input']))
|
|
|
|
|
|
|
|
|
|
cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16")
|
|
|
|
|
TestConv2DCUDNNFp16.__name__ = cls_name
|
|
|
|
@ -221,19 +215,13 @@ def create_test_cudnn_channel_last_fp16_class(parent, grad_check=True):
|
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
|
if core.is_float16_supported(place) and grad_check:
|
|
|
|
|
self.check_grad_with_place(
|
|
|
|
|
place, ['Input'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.02,
|
|
|
|
|
no_grad_set=set(['Filter']))
|
|
|
|
|
place, ['Input'], 'Output', no_grad_set=set(['Filter']))
|
|
|
|
|
|
|
|
|
|
def test_check_grad_no_input(self):
|
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
|
if core.is_float16_supported(place) and grad_check:
|
|
|
|
|
self.check_grad_with_place(
|
|
|
|
|
place, ['Filter'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.02,
|
|
|
|
|
no_grad_set=set(['Input']))
|
|
|
|
|
place, ['Filter'], 'Output', no_grad_set=set(['Input']))
|
|
|
|
|
|
|
|
|
|
def init_data_format(self):
|
|
|
|
|
self.data_format = "NHWC"
|
|
|
|
@ -397,7 +385,6 @@ class TestConv2dOp(OpTest):
|
|
|
|
|
self.check_grad_with_place(
|
|
|
|
|
place, ['Filter'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.02,
|
|
|
|
|
no_grad_set=set(['Input']),
|
|
|
|
|
check_dygraph=(self.use_mkldnn == False))
|
|
|
|
|
|
|
|
|
@ -827,7 +814,6 @@ class TestConv2dOp_v2(OpTest):
|
|
|
|
|
self.check_grad_with_place(
|
|
|
|
|
place, ['Filter'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.02,
|
|
|
|
|
no_grad_set=set(['Input']),
|
|
|
|
|
check_dygraph=(self.use_mkldnn == False))
|
|
|
|
|
|
|
|
|
|