From 8f051b36d542663903f98f8aa4c53187545111bf Mon Sep 17 00:00:00 2001
From: "xiaoli.liu@intel.com" <xiaoli.liu@intel.com>
Date: Tue, 25 Dec 2018 17:40:24 +0800
Subject: [PATCH 1/4] Enable INT8 pool OP

test=develop
---
 paddle/fluid/operators/pool_mkldnn_op.cc      |  31 ++-
 .../unittests/test_pool2d_int8_mkldnn_op.py   | 236 ++++++++++++++++++
 2 files changed, 256 insertions(+), 11 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py

diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc
index 0a9a29956a..f6f40b1daf 100644
--- a/paddle/fluid/operators/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/pool_mkldnn_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/operators/pool_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
@@ -71,7 +72,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-
     auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -130,20 +130,25 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
                           padding_right_bottom);
       }
-      auto src_md = platform::MKLDNNMemDesc(
-          src_tz, platform::MKLDNNGetDataType<T>(), input_format);
+
+      mkldnn::memory::data_type dt =
+          paddle::framework::ToMKLDNNDataType(input->type());
+
+      auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
 
       /* create memory descriptor for pooling without specified format
       * ('any') which lets a primitive (pooling in this case) choose
       * the memory format preferred for best performance
       */
-      auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
-                                            mkldnn::memory::format::any);
-
+      auto dst_md =
+          platform::MKLDNNMemDesc(dst_tz, dt, mkldnn::memory::format::any);
+      auto propagation = src_md.data.data_type == mkldnn_f32
+                             ? mkldnn::prop_kind::forward_training
+                             : mkldnn::prop_kind::forward_scoring;
       std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
-          CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top,
-                              padding_right_bottom, ksize, pooling_type,
-                              mkldnn_engine, ceil_mode, is_test);
+          CreatePrimitiveDesc(src_md, dst_md, propagation, strides,
+                              padding_left_top, padding_right_bottom, ksize,
+                              pooling_type, mkldnn_engine, ceil_mode, is_test);
 
       // save pool_pd into global device context to be referred in backward path
       if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd);
@@ -203,7 +208,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  private:
   std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc(
       const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst,
-      const std::vector<int>& stride, const std::vector<int>& padding_left_top,
+      const mkldnn::prop_kind& propagation, const std::vector<int>& stride,
+      const std::vector<int>& padding_left_top,
       const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
       const std::string& pooling_type, const mkldnn::engine& engine,
       bool ceil_mode, bool is_test) const {
@@ -411,6 +417,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::PoolMKLDNNOpKernel<float>);
+                   ops::PoolMKLDNNOpKernel<float>,
+                   ops::PoolMKLDNNOpKernel<int8_t>,
+                   ops::PoolMKLDNNOpKernel<uint8_t>);
+
 REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::PoolMKLDNNGradOpKernel<float>);
diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
new file mode 100644
index 0000000000..954d9993b2
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from __future__ import division
+
+import unittest
+import numpy as np
+
+import paddle.fluid.core as core
+from op_test import OpTest
+
+
+def adaptive_start_index(index, input_size, output_size):
+    return int(np.floor(index * input_size / output_size))
+
+
+def adaptive_end_index(index, input_size, output_size):
+    return int(np.ceil((index + 1) * input_size / output_size))
+
+
+def max_pool2D_forward_naive(x,
+                             ksize,
+                             strides,
+                             paddings,
+                             global_pool=0,
+                             ceil_mode=False,
+                             exclusive=True,
+                             adaptive=False):
+    N, C, H, W = x.shape
+    if global_pool == 1:
+        ksize = [H, W]
+    if adaptive:
+        H_out, W_out = ksize
+    else:
+        H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
+                 ) // strides[0] + 1 if ceil_mode else (
+                     H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
+        W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
+                 ) // strides[1] + 1 if ceil_mode else (
+                     W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
+    out = np.zeros((N, C, H_out, W_out))
+    for i in range(H_out):
+        for j in range(W_out):
+            if adaptive:
+                r_start = adaptive_start_index(i, H, ksize[0])
+                r_end = adaptive_end_index(i, H, ksize[0])
+                c_start = adaptive_start_index(j, W, ksize[1])
+                c_end = adaptive_end_index(j, W, ksize[1])
+            else:
+                r_start = np.max((i * strides[0] - paddings[0], 0))
+                r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
+                c_start = np.max((j * strides[1] - paddings[1], 0))
+                c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
+            x_masked = x[:, :, r_start:r_end, c_start:c_end]
+
+            out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
+    return out
+
+
+def avg_pool2D_forward_naive(x,
+                             ksize,
+                             strides,
+                             paddings,
+                             global_pool=0,
+                             ceil_mode=False,
+                             exclusive=True,
+                             adaptive=False):
+    N, C, H, W = x.shape
+    if global_pool == 1:
+        ksize = [H, W]
+    if adaptive:
+        H_out, W_out = ksize
+    else:
+        H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
+                 ) // strides[0] + 1 if ceil_mode else (
+                     H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
+        W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
+                 ) // strides[1] + 1 if ceil_mode else (
+                     W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
+    out = np.zeros((N, C, H_out, W_out))
+    for i in range(H_out):
+        for j in range(W_out):
+            if adaptive:
+                r_start = adaptive_start_index(i, H, ksize[0])
+                r_end = adaptive_end_index(i, H, ksize[0])
+                c_start = adaptive_start_index(j, W, ksize[1])
+                c_end = adaptive_end_index(j, W, ksize[1])
+            else:
+                r_start = np.max((i * strides[0] - paddings[0], 0))
+                r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
+                c_start = np.max((j * strides[1] - paddings[1], 0))
+                c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
+            x_masked = x[:, :, r_start:r_end, c_start:c_end]
+
+            field_size = ((r_end - r_start) * (c_end - c_start)) \
+                if (exclusive or adaptive) else (ksize[0] * ksize[1])
+            out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
+    return out
+
+
+class TestPool2D_Op(OpTest):
+    def setUp(self):
+        self.op_type = "pool2d"
+        self.use_cudnn = False
+        self.use_mkldnn = True
+        self.dtype = np.int8
+        self.init_test_case()
+        self.init_global_pool()
+        self.init_pool_type()
+        self.init_ceil_mode()
+        self.init_exclusive()
+        self.init_adaptive()
+        if self.global_pool:
+            self.paddings = [0 for _ in range(len(self.paddings))]
+        input = np.random.random(self.shape).astype(self.dtype)
+        output = self.pool2D_forward_naive(
+            input, self.ksize, self.strides, self.paddings, self.global_pool,
+            self.ceil_mode, self.exclusive,
+            self.adaptive).astype(self.dtype)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
+
+        self.attrs = {
+            'strides': self.strides,
+            'paddings': self.paddings,
+            'ksize': self.ksize,
+            'pooling_type': self.pool_type,
+            'global_pooling': self.global_pool,
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn,
+            'ceil_mode': self.ceil_mode,
+            'data_format':
+            'AnyLayout',  # TODO(dzhwinter) : should be fix latter
+            'exclusive': self.exclusive,
+            'adaptive': self.adaptive
+        }
+
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def init_test_case(self):
+        self.shape = [2, 3, 5, 5]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+        self.dtype = np.int8
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+
+    def init_global_pool(self):
+        self.global_pool = True
+
+    def init_ceil_mode(self):
+        self.ceil_mode = False
+
+    def init_exclusive(self):
+        self.exclusive = True
+
+    def init_adaptive(self):
+        self.adaptive = False
+
+
+class TestCase1(TestPool2D_Op):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+        self.dtype = np.int8
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+
+    def init_global_pool(self):
+        self.global_pool = False
+
+
+class TestCase2(TestPool2D_Op):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+        self.dtype = np.uint8
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+
+    def init_global_pool(self):
+        self.global_pool = False
+
+
+class TestCase3(TestPool2D_Op):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+        self.dtype = np.int8
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+
+
+class TestCase4(TestCase1):
+    def init_test_case(self):
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+        self.dtype = np.uint8
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+
+
+if __name__ == '__main__':
+    unittest.main()

From 157e79e8ecdb22c7aeda84cc7ef80bde63ecde0e Mon Sep 17 00:00:00 2001
From: "xiaoli.liu@intel.com" <xiaoli.liu@intel.com>
Date: Fri, 28 Dec 2018 00:54:01 +0800
Subject: [PATCH 2/4] fix unittest

test=develop
---
 .../paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
index 954d9993b2..e73ac7c0aa 100644
--- a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
@@ -148,7 +148,7 @@ class TestPool2D_Op(OpTest):
         self.outputs = {'Out': output}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output_with_place(core.CPUPlace(), atol=1e-5)
 
     def init_test_case(self):
         self.shape = [2, 3, 5, 5]

From 60eaf967eb6fa5273e268a72dc2c260ae3d348aa Mon Sep 17 00:00:00 2001
From: "xiaoli.liu@intel.com" <xiaoli.liu@intel.com>
Date: Sat, 29 Dec 2018 20:15:00 +0800
Subject: [PATCH 3/4] Clean unittest code.

test=develop
---
 .../unittests/test_pool2d_int8_mkldnn_op.py   | 216 ++++--------------
 .../tests/unittests/test_pool2d_mkldnn_op.py  |  45 ++--
 .../fluid/tests/unittests/test_pool2d_op.py   |   5 +-
 3 files changed, 65 insertions(+), 201 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
index e73ac7c0aa..f4495d0bc8 100644
--- a/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_int8_mkldnn_op.py
@@ -20,217 +20,91 @@ import numpy as np
 
 import paddle.fluid.core as core
 from op_test import OpTest
+from test_pool2d_op import TestPool2D_Op, avg_pool2D_forward_naive, max_pool2D_forward_naive
 
 
-def adaptive_start_index(index, input_size, output_size):
-    return int(np.floor(index * input_size / output_size))
-
-
-def adaptive_end_index(index, input_size, output_size):
-    return int(np.ceil((index + 1) * input_size / output_size))
-
-
-def max_pool2D_forward_naive(x,
-                             ksize,
-                             strides,
-                             paddings,
-                             global_pool=0,
-                             ceil_mode=False,
-                             exclusive=True,
-                             adaptive=False):
-    N, C, H, W = x.shape
-    if global_pool == 1:
-        ksize = [H, W]
-    if adaptive:
-        H_out, W_out = ksize
-    else:
-        H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
-                 ) // strides[0] + 1 if ceil_mode else (
-                     H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
-        W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
-                 ) // strides[1] + 1 if ceil_mode else (
-                     W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
-    out = np.zeros((N, C, H_out, W_out))
-    for i in range(H_out):
-        for j in range(W_out):
-            if adaptive:
-                r_start = adaptive_start_index(i, H, ksize[0])
-                r_end = adaptive_end_index(i, H, ksize[0])
-                c_start = adaptive_start_index(j, W, ksize[1])
-                c_end = adaptive_end_index(j, W, ksize[1])
-            else:
-                r_start = np.max((i * strides[0] - paddings[0], 0))
-                r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
-                c_start = np.max((j * strides[1] - paddings[1], 0))
-                c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
-            x_masked = x[:, :, r_start:r_end, c_start:c_end]
-
-            out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
-    return out
-
-
-def avg_pool2D_forward_naive(x,
-                             ksize,
-                             strides,
-                             paddings,
-                             global_pool=0,
-                             ceil_mode=False,
-                             exclusive=True,
-                             adaptive=False):
-    N, C, H, W = x.shape
-    if global_pool == 1:
-        ksize = [H, W]
-    if adaptive:
-        H_out, W_out = ksize
-    else:
-        H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
-                 ) // strides[0] + 1 if ceil_mode else (
-                     H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
-        W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
-                 ) // strides[1] + 1 if ceil_mode else (
-                     W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
-    out = np.zeros((N, C, H_out, W_out))
-    for i in range(H_out):
-        for j in range(W_out):
-            if adaptive:
-                r_start = adaptive_start_index(i, H, ksize[0])
-                r_end = adaptive_end_index(i, H, ksize[0])
-                c_start = adaptive_start_index(j, W, ksize[1])
-                c_end = adaptive_end_index(j, W, ksize[1])
-            else:
-                r_start = np.max((i * strides[0] - paddings[0], 0))
-                r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
-                c_start = np.max((j * strides[1] - paddings[1], 0))
-                c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
-            x_masked = x[:, :, r_start:r_end, c_start:c_end]
-
-            field_size = ((r_end - r_start) * (c_end - c_start)) \
-                if (exclusive or adaptive) else (ksize[0] * ksize[1])
-            out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
-    return out
-
-
-class TestPool2D_Op(OpTest):
-    def setUp(self):
-        self.op_type = "pool2d"
-        self.use_cudnn = False
+class TestPool2dMKLDNNInt8_Op(TestPool2D_Op):
+    def init_kernel_type(self):
         self.use_mkldnn = True
-        self.dtype = np.int8
-        self.init_test_case()
-        self.init_global_pool()
-        self.init_pool_type()
-        self.init_ceil_mode()
-        self.init_exclusive()
-        self.init_adaptive()
-        if self.global_pool:
-            self.paddings = [0 for _ in range(len(self.paddings))]
-        input = np.random.random(self.shape).astype(self.dtype)
-        output = self.pool2D_forward_naive(
-            input, self.ksize, self.strides, self.paddings, self.global_pool,
-            self.ceil_mode, self.exclusive, self.adaptive).astype(self.dtype)
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
-
-        self.attrs = {
-            'strides': self.strides,
-            'paddings': self.paddings,
-            'ksize': self.ksize,
-            'pooling_type': self.pool_type,
-            'global_pooling': self.global_pool,
-            'use_cudnn': self.use_cudnn,
-            'use_mkldnn': self.use_mkldnn,
-            'ceil_mode': self.ceil_mode,
-            'data_format':
-            'AnyLayout',  # TODO(dzhwinter) : should be fix latter
-            'exclusive': self.exclusive,
-            'adaptive': self.adaptive
-        }
-
-        self.outputs = {'Out': output}
-
-    def test_check_output(self):
-        self.check_output_with_place(core.CPUPlace(), atol=1e-5)
 
-    def init_test_case(self):
-        self.shape = [2, 3, 5, 5]
-        self.ksize = [3, 3]
-        self.strides = [1, 1]
-        self.paddings = [0, 0]
+    def init_data_type(self):
         self.dtype = np.int8
 
-    def init_pool_type(self):
-        self.pool_type = "avg"
-        self.pool2D_forward_naive = avg_pool2D_forward_naive
-
-    def init_global_pool(self):
-        self.global_pool = True
-
-    def init_ceil_mode(self):
-        self.ceil_mode = False
+    def setUp(self):
+        TestPool2D_Op.setUp(self)
+        assert self.dtype in [np.int8, np.uint8
+                              ], 'Dtype should be int8 or uint8'
 
-    def init_exclusive(self):
-        self.exclusive = True
+    def test_check_output(self):
+        self.check_output_with_place(core.CPUPlace(), atol=1e-5)
 
-    def init_adaptive(self):
-        self.adaptive = False
+    def test_check_grad(self):
+        pass
 
 
-class TestCase1(TestPool2D_Op):
+class TestCase1Avg(TestPool2dMKLDNNInt8_Op):
     def init_test_case(self):
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [0, 0]
-        self.dtype = np.int8
-
-    def init_pool_type(self):
-        self.pool_type = "avg"
-        self.pool2D_forward_naive = avg_pool2D_forward_naive
 
     def init_global_pool(self):
         self.global_pool = False
 
 
-class TestCase2(TestPool2D_Op):
+class TestCase2Avg(TestPool2dMKLDNNInt8_Op):
     def init_test_case(self):
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [1, 1]
-        self.dtype = np.uint8
-
-    def init_pool_type(self):
-        self.pool_type = "avg"
-        self.pool2D_forward_naive = avg_pool2D_forward_naive
 
     def init_global_pool(self):
         self.global_pool = False
 
 
-class TestCase3(TestPool2D_Op):
-    def init_test_case(self):
-        self.shape = [2, 3, 7, 7]
-        self.ksize = [3, 3]
-        self.strides = [1, 1]
-        self.paddings = [0, 0]
-        self.dtype = np.int8
-
+class TestCase0Max(TestPool2dMKLDNNInt8_Op):
     def init_pool_type(self):
         self.pool_type = "max"
         self.pool2D_forward_naive = max_pool2D_forward_naive
 
 
-class TestCase4(TestCase1):
-    def init_test_case(self):
-        self.shape = [2, 3, 7, 7]
-        self.ksize = [3, 3]
-        self.strides = [1, 1]
-        self.paddings = [1, 1]
-        self.dtype = np.uint8
+class TestCase1Max(TestCase1Avg):
+    def init_pool_type(self):
+        self.pool_type = "max"
+        self.pool2D_forward_naive = max_pool2D_forward_naive
 
+
+class TestCase2Max(TestCase2Avg):
     def init_pool_type(self):
         self.pool_type = "max"
         self.pool2D_forward_naive = max_pool2D_forward_naive
 
 
+def create_test_s8_u8_class(parent):
+    class TestS8Case(parent):
+        def init_data_type(self):
+            self.dtype = np.int8
+
+    class TestU8Case(parent):
+        def init_data_type(self):
+            self.dtype = np.uint8
+
+    cls_name_s8 = "{0}_{1}".format(parent.__name__, "mkldnn_s8")
+    cls_name_u8 = "{0}_{1}".format(parent.__name__, "mkldnn_u8")
+    TestS8Case.__name__ = cls_name_s8
+    TestU8Case.__name__ = cls_name_u8
+    globals()[cls_name_s8] = TestS8Case
+    globals()[cls_name_u8] = TestU8Case
+
+
+create_test_s8_u8_class(TestPool2dMKLDNNInt8_Op)
+create_test_s8_u8_class(TestCase1Avg)
+create_test_s8_u8_class(TestCase2Avg)
+create_test_s8_u8_class(TestCase0Max)
+create_test_s8_u8_class(TestCase1Max)
+create_test_s8_u8_class(TestCase2Max)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py
index 19f29c7826..7de5fefc14 100644
--- a/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py
@@ -18,35 +18,22 @@ import unittest
 from test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5
 
 
-class TestMKLDNNCase1(TestPool2D_Op):
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-
-class TestMKLDNNCase2(TestCase1):
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-
-class TestMKLDNNCase3(TestCase2):
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-
-class TestMKLDNNCase4(TestCase3):
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-
-class TestMKLDNNCase5(TestCase4):
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-
-class TestMKLDNNCase6(TestCase5):
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
+def create_test_mkldnn_class(parent):
+    class TestMKLDNNCase(parent):
+        def init_kernel_type(self):
+            self.use_mkldnn = True
+
+    cls_name = "{0}_{1}".format(parent.__name__, "MKLDNNOp")
+    TestMKLDNNCase.__name__ = cls_name
+    globals()[cls_name] = TestMKLDNNCase
+
+
+create_test_mkldnn_class(TestPool2D_Op)
+create_test_mkldnn_class(TestCase1)
+create_test_mkldnn_class(TestCase2)
+create_test_mkldnn_class(TestCase3)
+create_test_mkldnn_class(TestCase4)
+create_test_mkldnn_class(TestCase5)
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py
index 5ccdf082e8..92515add59 100644
--- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py
@@ -115,7 +115,7 @@ class TestPool2D_Op(OpTest):
         self.op_type = "pool2d"
         self.use_cudnn = False
         self.use_mkldnn = False
-        self.dtype = np.float32
+        self.init_data_type()
         self.init_test_case()
         self.init_global_pool()
         self.init_kernel_type()
@@ -177,6 +177,9 @@ class TestPool2D_Op(OpTest):
     def init_kernel_type(self):
         pass
 
+    def init_data_type(self):
+        self.dtype = np.float32
+
     def init_pool_type(self):
         self.pool_type = "avg"
         self.pool2D_forward_naive = avg_pool2D_forward_naive

From f34e779f4dc152efbecdedcdd561fa062aa79110 Mon Sep 17 00:00:00 2001
From: "xiaoli.liu@intel.com" <xiaoli.liu@intel.com>
Date: Thu, 10 Jan 2019 17:17:33 +0800
Subject: [PATCH 4/4] Enhance key generation for INT8 test.

test=develop
---
 paddle/fluid/operators/pool_mkldnn_op.cc | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc
index f6f40b1daf..f4bad7b712 100644
--- a/paddle/fluid/operators/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/pool_mkldnn_op.cc
@@ -35,6 +35,7 @@ static std::string gethash(const memory::dims& input_dims,
                            const std::vector<int>& ksize,
                            const std::vector<int>& strides,
                            const std::vector<int>& paddings,
+                           const memory::data_type& dt,
                            const std::string& suffix) {
   auto dims2str = [](const memory::dims& operand_dims) {
     std::string dstr = "";
@@ -44,7 +45,7 @@ static std::string gethash(const memory::dims& input_dims,
     return dstr;
   };
   return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
-         dims2str(paddings) + pooling_type + suffix;
+         dims2str(paddings) + std::to_string(dt) + pooling_type + suffix;
 }
 
 static inline int ComputeCeiledOutput(int input_size, int kernel_size,
@@ -111,8 +112,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto input_format = input->format();
     memory::format output_format{memory::format::format_undef};
 
+    mkldnn::memory::data_type dt =
+        paddle::framework::ToMKLDNNDataType(input->type());
     const std::string key = gethash(src_tz, pooling_type, ksize, strides,
-                                    paddings, ctx.op().Output("Out"));
+                                    paddings, dt, ctx.op().Output("Out"));
     const std::string key_pool_p = key + "@pool_p";
     const std::string key_pool_pd = key + "@pool_pd";
     const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
@@ -131,9 +134,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                           padding_right_bottom);
       }
 
-      mkldnn::memory::data_type dt =
-          paddle::framework::ToMKLDNNDataType(input->type());
-
       auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
 
       /* create memory descriptor for pooling without specified format
@@ -293,8 +293,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     // Get an unique name from "argument" name of "Out" variable
     // This name will be used as key when referring info from device context
-    const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
-                                    paddings, ctx.op().Input("Out"));
+    const std::string key =
+        gethash(diff_src_tz, pooling_type, ksize, strides,
+                paddings, memory::data_type::f32, ctx.op().Input("Out"));
     const std::string key_pool_bwd_p = key + "@pool_bwd_p";
     const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
     const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
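
For reference, the following minimal NumPy sketch reproduces the reference path the INT8 unittests rely on: the naive pooling runs in floating point and is then truncated back to the quantized dtype with astype, mirroring pool2D_forward_naive(...).astype(self.dtype) in the tests, which is why the MKL-DNN result is compared via check_output_with_place(core.CPUPlace(), atol=1e-5). This is a standalone illustration only, not part of the patches; the helper name naive_avg_pool2d_int8 and the random int8 input are hypothetical, and exclusive averaging (divide by the actual window size) is assumed, matching the tests' default exclusive=True.

    # Illustrative sketch (not part of the patch series).
    import numpy as np

    def naive_avg_pool2d_int8(x, ksize, strides, paddings):
        # x: NCHW int8 tensor; pool in float, then truncate back to int8,
        # the same way the unittest does `...astype(self.dtype)`.
        N, C, H, W = x.shape
        H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
        W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
        out = np.zeros((N, C, H_out, W_out))
        for i in range(H_out):
            for j in range(W_out):
                r0 = max(i * strides[0] - paddings[0], 0)
                r1 = min(i * strides[0] + ksize[0] - paddings[0], H)
                c0 = max(j * strides[1] - paddings[1], 0)
                c1 = min(j * strides[1] + ksize[1] - paddings[1], W)
                window = x[:, :, r0:r1, c0:c1].astype(np.float64)
                # exclusive average: divide by the actual window area
                out[:, :, i, j] = window.mean(axis=(2, 3))
        # C-style truncation toward zero, exactly like the test's astype
        return out.astype(np.int8)

    x = np.random.randint(-128, 128, size=(2, 3, 5, 5)).astype(np.int8)
    ref = naive_avg_pool2d_int8(x, ksize=[3, 3], strides=[1, 1], paddings=[0, 0])
    print(ref.shape)  # (2, 3, 3, 3)

Under these assumptions, ref is the array an s8 pool2d kernel would be checked against; for u8 coverage the same sketch applies with np.uint8 inputs, which is what the generated *_mkldnn_s8 / *_mkldnn_u8 test classes in PATCH 3/4 exercise.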