Enable test conv2d ngraph (#22074)

release/1.7
Leo Zhao 5 years ago committed by Tao Luo
parent c112b645c4
commit 1c39efb783

@ -1,8 +1,6 @@
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_conv2d_ngraph_op)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_use_ngraph=true)
endforeach(TEST_OP)

@ -20,6 +20,13 @@ from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGr
import numpy as np
class TestNGRAPHWithStride(TestWithStride):
def init_test_case(self):
super(TestNGRAPHWithStride, self).init_test_case()
self.use_cuda = False
self.dtype = np.float32
class TestNGRAPHDepthwiseConv(TestDepthwiseConv):
def init_test_case(self):
super(TestNGRAPHDepthwiseConv, self).init_test_case()
@ -55,7 +62,7 @@ class TestNGRAPHDepthwiseConvWithDilation2(TestDepthwiseConvWithDilation2):
self.dtype = np.float32
del TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2
del TestWithStride, TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save