Enable INT8 pool OP (#15046)
* Enable INT8 pool OP test=develop * fix unittest test=develop * Clean unittest code. test=developrevert-15207-remove_op_handle_lock_and_fix_var
parent
227e0c4518
commit
8eb1f26211
@ -0,0 +1,110 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
from __future__ import division
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import paddle.fluid.core as core
|
||||||
|
from op_test import OpTest
|
||||||
|
from test_pool2d_op import TestPool2D_Op, avg_pool2D_forward_naive, max_pool2D_forward_naive
|
||||||
|
|
||||||
|
|
||||||
|
class TestPool2dMKLDNNInt8_Op(TestPool2D_Op):
|
||||||
|
def init_kernel_type(self):
|
||||||
|
self.use_mkldnn = True
|
||||||
|
|
||||||
|
def init_data_type(self):
|
||||||
|
self.dtype = np.int8
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
TestPool2D_Op.setUp(self)
|
||||||
|
assert self.dtype in [np.int8, np.uint8
|
||||||
|
], 'Dtype should be int8 or uint8'
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output_with_place(core.CPUPlace(), atol=1e-5)
|
||||||
|
|
||||||
|
def test_check_grad(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase1Avg(TestPool2dMKLDNNInt8_Op):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.shape = [2, 3, 7, 7]
|
||||||
|
self.ksize = [3, 3]
|
||||||
|
self.strides = [1, 1]
|
||||||
|
self.paddings = [0, 0]
|
||||||
|
|
||||||
|
def init_global_pool(self):
|
||||||
|
self.global_pool = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase2Avg(TestPool2dMKLDNNInt8_Op):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.shape = [2, 3, 7, 7]
|
||||||
|
self.ksize = [3, 3]
|
||||||
|
self.strides = [1, 1]
|
||||||
|
self.paddings = [1, 1]
|
||||||
|
|
||||||
|
def init_global_pool(self):
|
||||||
|
self.global_pool = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase0Max(TestPool2dMKLDNNInt8_Op):
|
||||||
|
def init_pool_type(self):
|
||||||
|
self.pool_type = "max"
|
||||||
|
self.pool2D_forward_naive = max_pool2D_forward_naive
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase1Max(TestCase1Avg):
|
||||||
|
def init_pool_type(self):
|
||||||
|
self.pool_type = "max"
|
||||||
|
self.pool2D_forward_naive = max_pool2D_forward_naive
|
||||||
|
|
||||||
|
|
||||||
|
class TestCase2Max(TestCase2Avg):
|
||||||
|
def init_pool_type(self):
|
||||||
|
self.pool_type = "max"
|
||||||
|
self.pool2D_forward_naive = max_pool2D_forward_naive
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_s8_u8_class(parent):
|
||||||
|
class TestS8Case(parent):
|
||||||
|
def init_data_type(self):
|
||||||
|
self.dtype = np.int8
|
||||||
|
|
||||||
|
class TestU8Case(parent):
|
||||||
|
def init_data_type(self):
|
||||||
|
self.dtype = np.uint8
|
||||||
|
|
||||||
|
cls_name_s8 = "{0}_{1}".format(parent.__name__, "mkldnn_s8")
|
||||||
|
cls_name_u8 = "{0}_{1}".format(parent.__name__, "mkldnn_u8")
|
||||||
|
TestS8Case.__name__ = cls_name_s8
|
||||||
|
TestU8Case.__name__ = cls_name_u8
|
||||||
|
globals()[cls_name_s8] = TestS8Case
|
||||||
|
globals()[cls_name_u8] = TestU8Case
|
||||||
|
|
||||||
|
|
||||||
|
create_test_s8_u8_class(TestPool2dMKLDNNInt8_Op)
|
||||||
|
create_test_s8_u8_class(TestCase1Avg)
|
||||||
|
create_test_s8_u8_class(TestCase2Avg)
|
||||||
|
create_test_s8_u8_class(TestCase0Max)
|
||||||
|
create_test_s8_u8_class(TestCase1Max)
|
||||||
|
create_test_s8_u8_class(TestCase2Max)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue