@@ -35,6 +35,8 @@ def conv2d_residual_naive(out, residual):
     return out


+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
 class TestConv2dBf16Op(TestConv2dOp):
     def setUp(self):
         self.op_type = "conv2d"
@@ -42,9 +44,9 @@ class TestConv2dBf16Op(TestConv2dOp):
         self.exhaustive_search = False
         self.use_cuda = False
         self.use_mkldnn = True
         self._cpu_only = True
         self.weight_type = np.float32
         self.input_type = np.float32
         self.use_mkldnn = True
         self.mkldnn_data_type = "bfloat16"
         self.force_fp32_output = False
         self.init_group()
@@ -205,5 +207,4 @@ class TestWithInput1x1Filter1x1(TestConv2dBf16Op):


 if __name__ == '__main__':
-    if core.supports_bfloat16():
-        unittest.main()
+    unittest.main()
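For reference, the gating pattern in the diff can be reproduced with plain unittest alone. The sketch below is a minimal, self-contained illustration, not the test file itself: supports_bfloat16() here is a hypothetical stand-in for core.supports_bfloat16() (in the original test, core comes from paddle.fluid). The point of the change is that the capability check moves from a guard around unittest.main() into a class-level unittest.skipIf, so on unsupported hardware the tests show up as skipped in the report instead of silently never running.

import unittest


def supports_bfloat16():
    # Hypothetical stand-in for core.supports_bfloat16(); in the real test
    # this asks whether the CPU/oneDNN build can evaluate bfloat16 kernels.
    return False


@unittest.skipIf(not supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestConv2dBf16Sketch(unittest.TestCase):
    def test_runs_only_with_bf16_support(self):
        # Executed only when the skip condition is False, i.e. when
        # bfloat16 evaluation is reported as available.
        self.assertTrue(supports_bfloat16())


if __name__ == '__main__':
    unittest.main()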