fix unit test

revert-4814-Add_sequence_project_op
chengduoZH 7 years ago
parent b15c69f59d
commit 6f02fe7dfd

@ -130,8 +130,30 @@ class TestSeqProject(OpTest):
max_relative_error=0.05,
no_grad_set=set(['X', 'PaddingData']))
def test_check_grad_input_filter(self):
self.check_grad(
['X', 'Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['PaddingData']))
def test_check_grad_padding_input(self):
if self.padding_trainable:
self.check_grad(
['X', 'PaddingData'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_padding_filter(self):
if self.padding_trainable:
self.check_grad(
['PaddingData', 'Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['X']))
def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11
self.context_start = 0
self.context_length = 1
@ -144,7 +166,6 @@ class TestSeqProject(OpTest):
class TestSeqProjectCase1(TestSeqProject):
def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11
self.context_start = -1
self.context_length = 3
@ -157,7 +178,6 @@ class TestSeqProjectCase1(TestSeqProject):
class TestSeqProjectCase2(TestSeqProject):
def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 25
self.context_start = 2
self.context_length = 3

Loading…
Cancel
Save