|
|
|
@ -73,13 +73,22 @@ class TestConv2dOp(OpTest):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
|
self.check_grad(set(['Input', 'Filter']), 'Output')
|
|
|
|
|
self.check_grad(
|
|
|
|
|
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
|
|
|
|
|
|
|
|
|
|
def test_check_grad_no_filter(self):
|
|
|
|
|
self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter']))
|
|
|
|
|
self.check_grad(
|
|
|
|
|
['Input'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.05,
|
|
|
|
|
no_grad_set=set(['Filter']))
|
|
|
|
|
|
|
|
|
|
def test_check_grad_no_input(self):
|
|
|
|
|
self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input']))
|
|
|
|
|
self.check_grad(
|
|
|
|
|
['Filter'],
|
|
|
|
|
'Output',
|
|
|
|
|
max_relative_error=0.05,
|
|
|
|
|
no_grad_set=set(['Input']))
|
|
|
|
|
|
|
|
|
|
def init_groups(self):
|
|
|
|
|
self.groups = 1
|
|
|
|
|