fix unit test

revert-4814-Add_sequence_project_op
chengduoZH 8 years ago
parent b15c69f59d
commit 6f02fe7dfd

@ -130,8 +130,30 @@ class TestSeqProject(OpTest):
max_relative_error=0.05, max_relative_error=0.05,
no_grad_set=set(['X', 'PaddingData'])) no_grad_set=set(['X', 'PaddingData']))
def test_check_grad_input_filter(self):
self.check_grad(
['X', 'Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['PaddingData']))
def test_check_grad_padding_input(self):
if self.padding_trainable:
self.check_grad(
['X', 'PaddingData'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_padding_filter(self):
if self.padding_trainable:
self.check_grad(
['PaddingData', 'Filter'],
'Out',
max_relative_error=0.05,
no_grad_set=set(['X']))
def init_test_case(self): def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11 self.input_row = 11
self.context_start = 0 self.context_start = 0
self.context_length = 1 self.context_length = 1
@ -144,7 +166,6 @@ class TestSeqProject(OpTest):
class TestSeqProjectCase1(TestSeqProject): class TestSeqProjectCase1(TestSeqProject):
def init_test_case(self): def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 11 self.input_row = 11
self.context_start = -1 self.context_start = -1
self.context_length = 3 self.context_length = 3
@ -157,7 +178,6 @@ class TestSeqProjectCase1(TestSeqProject):
class TestSeqProjectCase2(TestSeqProject): class TestSeqProjectCase2(TestSeqProject):
def init_test_case(self): def init_test_case(self):
self.op_type = "sequence_project"
self.input_row = 25 self.input_row = 25
self.context_start = 2 self.context_start = 2
self.context_length = 3 self.context_length = 3

Loading…
Cancel
Save