Clean unittest code.

test=develop
recover_files
xiaoli.liu@intel.com 6 years ago
parent 157e79e8ec
commit 60eaf967eb

@ -18,35 +18,22 @@ import unittest
from test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5
class TestMKLDNNCase1(TestPool2D_Op):
def init_kernel_type(self):
self.use_mkldnn = True
class TestMKLDNNCase2(TestCase1):
def init_kernel_type(self):
self.use_mkldnn = True
class TestMKLDNNCase3(TestCase2):
def init_kernel_type(self):
self.use_mkldnn = True
class TestMKLDNNCase4(TestCase3):
def init_kernel_type(self):
self.use_mkldnn = True
class TestMKLDNNCase5(TestCase4):
def init_kernel_type(self):
self.use_mkldnn = True
class TestMKLDNNCase6(TestCase5):
def init_kernel_type(self):
self.use_mkldnn = True
def create_test_mkldnn_class(parent):
class TestMKLDNNCase(parent):
def init_kernel_type(self):
self.use_mkldnn = True
cls_name = "{0}_{1}".format(parent.__name__, "MKLDNNOp")
TestMKLDNNCase.__name__ = cls_name
globals()[cls_name] = TestMKLDNNCase
create_test_mkldnn_class(TestPool2D_Op)
create_test_mkldnn_class(TestCase1)
create_test_mkldnn_class(TestCase2)
create_test_mkldnn_class(TestCase3)
create_test_mkldnn_class(TestCase4)
create_test_mkldnn_class(TestCase5)
if __name__ == '__main__':
unittest.main()

@ -115,7 +115,7 @@ class TestPool2D_Op(OpTest):
self.op_type = "pool2d"
self.use_cudnn = False
self.use_mkldnn = False
self.dtype = np.float32
self.init_data_type()
self.init_test_case()
self.init_global_pool()
self.init_kernel_type()
@ -177,6 +177,9 @@ class TestPool2D_Op(OpTest):
def init_kernel_type(self):
pass
def init_data_type(self):
self.dtype = np.float32
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive

Loading…
Cancel
Save