|
|
|
@ -16,7 +16,7 @@ from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
from paddle.fluid.tests.unittests.op_test import OpTest
|
|
|
|
|
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2DOp
|
|
|
|
@ -28,6 +28,8 @@ def conv2d_forward_refer(input, filter, group, conv_param):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(not core.supports_bfloat16(),
|
|
|
|
|
"place does not support BF16 evaluation")
|
|
|
|
|
class TestConv2DInt8Op(TestConv2DOp):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "conv2d"
|
|
|
|
@ -289,43 +291,31 @@ def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual):
|
|
|
|
|
def create_test_int8_class(parent):
|
|
|
|
|
|
|
|
|
|
#--------------------test conv2d s8 in and u8 out--------------------
|
|
|
|
|
|
|
|
|
|
class TestS8U8Case(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.int8, "relu", False)
|
|
|
|
|
|
|
|
|
|
#--------------------test conv2d s8 in and s8 out--------------------
|
|
|
|
|
|
|
|
|
|
class TestS8S8Case(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.int8, "", False)
|
|
|
|
|
|
|
|
|
|
#--------------------test conv2d u8 in and s8 out--------------------
|
|
|
|
|
|
|
|
|
|
class TestU8S8Case(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.uint8, "", False)
|
|
|
|
|
|
|
|
|
|
#--------------------test conv2d u8 in and u8 out without residual fuse--------------------
|
|
|
|
|
|
|
|
|
|
class TestU8U8Case(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.uint8, "relu", False)
|
|
|
|
|
|
|
|
|
|
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
|
|
|
|
|
|
|
|
|
|
class TestS8U8ResCase(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.int8, "relu", True)
|
|
|
|
|
|
|
|
|
|
#--------------------test conv2d s8 in and s8 out with residual fuse--------------------
|
|
|
|
|
|
|
|
|
|
class TestS8S8ResCase(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.int8, "", True)
|
|
|
|
|
|
|
|
|
|
#--------------------test conv2d u8 in and s8 out with residual fuse--------------------
|
|
|
|
|
|
|
|
|
|
class TestU8S8ResCase(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.uint8, "", True)
|
|
|
|
@ -334,8 +324,7 @@ def create_test_int8_class(parent):
|
|
|
|
|
cls_name_s8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
|
|
|
|
|
cls_name_u8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
|
|
|
|
|
cls_name_u8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1")
|
|
|
|
|
cls_name_s8u8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
|
|
|
|
|
"1", "1")
|
|
|
|
|
|
|
|
|
|
cls_name_s8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
|
|
|
|
|
"0", "1")
|
|
|
|
|
cls_name_u8s8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
|
|
|
|
@ -344,17 +333,27 @@ def create_test_int8_class(parent):
|
|
|
|
|
TestS8S8Case.__name__ = cls_name_s8s8
|
|
|
|
|
TestU8S8Case.__name__ = cls_name_u8s8
|
|
|
|
|
TestU8U8Case.__name__ = cls_name_u8u8
|
|
|
|
|
TestS8U8ResCase.__name__ = cls_name_s8u8_re_1
|
|
|
|
|
|
|
|
|
|
TestS8S8ResCase.__name__ = cls_name_s8s8_re_1
|
|
|
|
|
TestU8S8ResCase.__name__ = cls_name_u8s8_re_1
|
|
|
|
|
globals()[cls_name_s8u8] = TestS8U8Case
|
|
|
|
|
globals()[cls_name_s8s8] = TestS8S8Case
|
|
|
|
|
globals()[cls_name_u8s8] = TestU8S8Case
|
|
|
|
|
globals()[cls_name_u8u8] = TestU8U8Case
|
|
|
|
|
globals()[cls_name_s8u8_re_1] = TestS8U8ResCase
|
|
|
|
|
globals()[cls_name_s8s8_re_1] = TestS8S8ResCase
|
|
|
|
|
globals()[cls_name_u8s8_re_1] = TestU8S8ResCase
|
|
|
|
|
|
|
|
|
|
if os.name != 'nt':
|
|
|
|
|
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
|
|
|
|
|
class TestS8U8ResCase(parent):
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
init_data_type_with_fusion(self, np.int8, "relu", True)
|
|
|
|
|
|
|
|
|
|
cls_name_s8u8_re_1 = "{0}_relu_{1}_residual_{2}".format(parent.__name__,
|
|
|
|
|
"1", "1")
|
|
|
|
|
TestS8U8ResCase.__name__ = cls_name_s8u8_re_1
|
|
|
|
|
globals()[cls_name_s8u8_re_1] = TestS8U8ResCase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
create_test_int8_class(TestConv2DInt8Op)
|
|
|
|
|
create_test_int8_class(TestWithPad)
|
|
|
|
@ -387,4 +386,6 @@ class TestConv2DOp_Valid_INT_MKLDNN(TestConv2DOp_AsyPadding_INT_MKLDNN):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
from paddle import enable_static
|
|
|
|
|
enable_static()
|
|
|
|
|
unittest.main()
|
|
|
|
|