From 213ec37d6ad84c3774f1a5e203566dc47a1b63da Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Thu, 25 Oct 2018 16:18:04 +0200 Subject: [PATCH 01/15] MKLDNN elementwise_add: simple initial implementation of the operator for MKLDNN format --- .../operators/elementwise_mul_mkldnn_op.cc | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 paddle/fluid/operators/elementwise_mul_mkldnn_op.cc diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc new file mode 100644 index 0000000000..22289ab417 --- /dev/null +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -0,0 +1,99 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/elementwise_op.h" +#include "paddle/fluid/operators/elementwise_op_function.h" + +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using framework::DataLayout; + +template +class ElementwiseMulMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using Tensor = framework::Tensor; + + int axis = ctx.Attr("axis"); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + const T* x_data = x->data(); + const T* y_data = y->data(); + T* z_data = z->mutable_data(ctx.GetPlace()); + + auto x_dims = x->dims(); + auto y_dims_untrimmed = y->dims(); + + if (x_dims != y_dims_untrimmed) { + int pre, n, post; + get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); + + if (post == 1) { + PADDLE_THROW("Not implemented when post is 1"); + } else { + // Just check whether it works for RE-Resnext. 
+ + PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); + + int n = x_dims[0]; + int c = x_dims[1]; + int h = x_dims[2]; + int w = x_dims[3]; + + PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, + "Y should be in nc format"); + + constexpr int simd_width = 16; + int C = c / simd_width; + + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + for (int hi = 0; hi < h; hi++) { + for (int wi = 0; wi < w; wi++) { + auto ptr_x = x_data + ni * C * h * w * simd_width + + ci * h * w * simd_width + hi * w * simd_width + + wi * simd_width; + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + + auto ptr_z = z_data + ni * C * h * w * simd_width + + ci * h * w * simd_width + hi * w * simd_width + + wi * simd_width; + + for (int i = 0; i < simd_width; i++) { + ptr_z[i] = ptr_x[i] * ptr_y[i]; + } + } + } + } + } + } + + z->set_layout(DataLayout::kMKLDNN); + z->set_format(x->format()); + } else { + PADDLE_THROW("Not implemented when dims are equal"); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace, + ops::ElementwiseMulMKLDNNKernel) From 2d73ad180ae80d1da4ae319106a22f8a11c79da9 Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Thu, 25 Oct 2018 17:07:17 +0200 Subject: [PATCH 02/15] MKLDNN elementwise_mul: simple xbyak version for AVX512 --- .../operators/elementwise_mul_mkldnn_op.cc | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 22289ab417..595a6232da 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -17,11 +17,29 @@ limitations under the License. */ #include "paddle/fluid/platform/mkldnn_helper.h" +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + namespace paddle { namespace operators { using framework::DataLayout; +struct vector_mul : public Xbyak::CodeGenerator { + vector_mul() { + // RDI is ptr X + // RSI is ptr Y + // RDX is ptr Z + + vmovups(zmm2, ptr[rdi]); + vmovups(zmm3, ptr[rsi]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx], zmm1); + + ret(); + } +}; + template class ElementwiseMulMKLDNNKernel : public framework::OpKernel { public: @@ -61,6 +79,14 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { constexpr int simd_width = 16; int C = c / simd_width; + vector_mul mul; + + using mul_func_t = void (*)(const float*, const float*, float*); + + mul_func_t mul_func = (mul_func_t)mul.getCode(); + + auto ptr_x = x_data; + for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { for (int hi = 0; hi < h; hi++) { @@ -74,9 +100,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { ci * h * w * simd_width + hi * w * simd_width + wi * simd_width; - for (int i = 0; i < simd_width; i++) { - ptr_z[i] = ptr_x[i] * ptr_y[i]; - } + mul_func(ptr_x, ptr_y, ptr_z); } } } From ad09facafecfd7157ea18d3b433c15135d914978 Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Fri, 26 Oct 2018 14:01:44 +0200 Subject: [PATCH 03/15] MKLDNN elementwise_mul: CPU tests initially refactored. 
MKLDNN mul test for broadcast added --- .../operators/elementwise_mul_mkldnn_op.cc | 2 - .../unittests/test_elementwise_add_op.py | 6 --- .../test_elementwise_mul_mkldnn_op.py | 50 +++++++++++++++++++ .../unittests/test_elementwise_mul_op.py | 44 +++++++++++----- 4 files changed, 81 insertions(+), 21 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 595a6232da..13e4cc04df 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -85,8 +85,6 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { mul_func_t mul_func = (mul_func_t)mul.getCode(); - auto ptr_x = x_data; - for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { for (int hi = 0; hi < h; hi++) { diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 5aec5d8e38..d71a9c0151 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -43,19 +43,13 @@ class TestElementwiseAddOp(OpTest): self.check_output() def test_check_grad_normal(self): - if self.dtype == np.float16: - return self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) def test_check_grad_ingore_x(self): - if self.dtype == np.float16: - return self.check_grad( ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) def test_check_grad_ingore_y(self): - if self.dtype == np.float16: - return self.check_grad( ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py new file mode 100644 index 0000000000..a0581d16de --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py @@ -0,0 +1,50 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from test_elementwise_mul_op import * + + +class ElementwiseMulMKLDNNOp(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = x * self.y.reshape(1, 16, 1, 1) + self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 53409e436c..57ba34f833 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -21,13 +21,24 @@ from paddle.fluid.op import Operator class ElementwiseMulOp(OpTest): + def init_kernel_type(self): + self.use_mkldnn = False + def setUp(self): self.op_type = "elementwise_mul" + self.dtype = np.float32 + self.axis = -1 + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + self.inputs = { - 'X': np.random.uniform(0.1, 1, [13, 17]).astype("float64"), - 'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float64") + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) } - self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} + self.outputs = {'Out': self.out} + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} def test_check_output(self): self.check_output() @@ -41,6 +52,17 @@ class ElementwiseMulOp(OpTest): def test_check_grad_ingore_y(self): self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + def init_dtype(self): + pass + + def init_axis(self): + pass + class TestElementwiseMulOp_scalar(ElementwiseMulOp): def setUp(self): @@ -63,17 +85,13 @@ class TestElementwiseMulOp_Vector(ElementwiseMulOp): class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp): - def setUp(self): - self.op_type = "elementwise_mul" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float64), - 'Y': np.random.rand(2).astype(np.float64) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + self.out = self.x * self.y.reshape(2, 1, 1) - self.attrs = {'axis': 0} - self.outputs = { - 'Out': self.inputs['X'] * self.inputs['Y'].reshape(2, 1, 1) - } + def init_axis(self): + self.axis = 0 class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp): From 700bcbf74fa5c7b43fa183063e9bbdfc2bd23265 Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Sun, 28 Oct 2018 02:00:34 +0100 Subject: [PATCH 04/15] MKLDNN elementwise_mul: h and w loops implemented in xbyak --- .../operators/elementwise_mul_mkldnn_op.cc | 58 +++++++++++++------ 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc 
b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 13e4cc04df..21716e271d 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -30,16 +30,42 @@ struct vector_mul : public Xbyak::CodeGenerator { // RDI is ptr X // RSI is ptr Y // RDX is ptr Z + // RCX is h + // r8 is w - vmovups(zmm2, ptr[rdi]); + push(rbx); + + xor_(rax, rax); + xor_(r10, r10); vmovups(zmm3, ptr[rsi]); - vmulps(zmm1, zmm2, zmm3); - vmovups(ptr[rdx], zmm1); + L("h_loop"); + xor_(rbx, rbx); + L("w_loop"); + vmovups(zmm2, ptr[rdi + rax]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx + rax], zmm1); + add(rax, 64); + inc(rbx); + cmp(r8, rbx); + jnz("w_loop"); + inc(r10); + cmp(r10, rcx); + jnz("h_loop"); + + pop(rbx); ret(); } }; +void check(const float* x, const float* y, float* z, int w) { + for (int wi = 0; wi < w; wi++) { + for (int i = 0; i < 16; i++) { + z[wi * 16 + i] = x[wi * 16 + i] * y[i]; + } + } +} + template class ElementwiseMulMKLDNNKernel : public framework::OpKernel { public: @@ -65,7 +91,6 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { PADDLE_THROW("Not implemented when post is 1"); } else { // Just check whether it works for RE-Resnext. - PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); int n = x_dims[0]; @@ -81,26 +106,21 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { vector_mul mul; - using mul_func_t = void (*)(const float*, const float*, float*); + using mul_func_t = + void (*)(const float*, const float*, float*, int, int); mul_func_t mul_func = (mul_func_t)mul.getCode(); for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { - for (int hi = 0; hi < h; hi++) { - for (int wi = 0; wi < w; wi++) { - auto ptr_x = x_data + ni * C * h * w * simd_width + - ci * h * w * simd_width + hi * w * simd_width + - wi * simd_width; - auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; - - auto ptr_z = z_data + ni * C * h * w * simd_width + - ci * h * w * simd_width + hi * w * simd_width + - wi * simd_width; - - mul_func(ptr_x, ptr_y, ptr_z); - } - } + auto ptr_x = + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_z = + z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + mul_func(ptr_x, ptr_y, ptr_z, h, w); } } } From 4e54ab76ecb7e86dcfbfd59824bc2c5593513809 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 6 Nov 2018 10:57:15 +0100 Subject: [PATCH 05/15] Add HasAttr method to Operator --- paddle/fluid/framework/operator.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 40b0130b26..6918e030bf 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -100,6 +100,7 @@ class OperatorBase { const std::string& Type() const { return type_; } + bool HasAttr(const std::string& name) const { return attrs_.count(name); } template inline const T& Attr(const std::string& name) const { PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", From ed31936ba1343a84460d2fd1883f75e0951ce353 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 6 Nov 2018 11:04:39 +0100 Subject: [PATCH 06/15] MKLDNN elementwise_mul: Support NCHW, update UT --- .../operators/elementwise/elementwise_op.h | 14 ++ .../operators/elementwise_mul_mkldnn_op.cc | 124 +++++++++++++----- .../test_elementwise_mul_mkldnn_op.py | 29 +++- 3 files changed, 135 insertions(+), 32 
deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index f01f67692e..16d919689c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -97,6 +97,20 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { .EqualGreaterThan(-1); AddAttr("use_mkldnn", "(bool, default false). Used by MKLDNN.") .SetDefault(false); + AddAttr( + "x_data_format", + "(string, default NCHW) Only used in mkldnn" + "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " + "Defaults to \"\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault(""); + AddAttr( + "y_data_format", + "(string, default \"\") Only used in mkldnn" + "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " + "Defaults to \"\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault(""); AddComment(string::Sprintf(R"DOC( Elementwise %s Operator diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 21716e271d..d66c58bd45 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" @@ -24,6 +25,7 @@ namespace paddle { namespace operators { using framework::DataLayout; +using mkldnn::memory; struct vector_mul : public Xbyak::CodeGenerator { vector_mul() { @@ -66,6 +68,33 @@ void check(const float* x, const float* y, float* z, int w) { } } +static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { + std::transform(format.begin(), format.end(), format.begin(), ::tolower); + + if(!format.compare("nchw")) { + return memory::format::nchw; + } else if(!format.compare("nchw16c")) { + return memory::format::nChw16c; + } else if(!format.compare("nchw8c")) { + return memory::format::nChw8c; + } else if(!format.compare("nhwc")) { + return memory::format::nhwc; + } else { + return memory::format::any; + } +} + +static void UpdateDataFormat(const framework::ExecutionContext& ctx, + framework::Tensor* tensor, const char* attribute) { + if(ctx.op().HasAttr(attribute)) { + auto format_as_string = ctx.Attr(attribute); + auto format = StringToMKLDNNFormat(format_as_string); + if (format != memory::format::any) { + tensor->set_format(format); + } + } +} + template class ElementwiseMulMKLDNNKernel : public framework::OpKernel { public: @@ -83,52 +112,87 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto x_dims = x->dims(); auto y_dims_untrimmed = y->dims(); - if (x_dims != y_dims_untrimmed) { - int pre, n, post; - get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); + UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); + UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); - if (post == 1) { - PADDLE_THROW("Not implemented when post is 1"); - } else { - // Just check whether it works for RE-Resnext. 
- PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); + if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) { + if (x_dims != y_dims_untrimmed) { + int pre, n, post; + get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); + + if (post == 1) { + PADDLE_THROW("Not implemented when post is 1"); + } else { + // Just check whether it works for RE-Resnext. + PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); - int n = x_dims[0]; - int c = x_dims[1]; - int h = x_dims[2]; - int w = x_dims[3]; + int n = x_dims[0]; + int c = x_dims[1]; + int h = x_dims[2]; + int w = x_dims[3]; - PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, - "Y should be in nc format"); + PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, + "Y should be in nc format"); - constexpr int simd_width = 16; - int C = c / simd_width; + constexpr int simd_width = 16; + int C = c / simd_width; - vector_mul mul; + vector_mul mul; - using mul_func_t = - void (*)(const float*, const float*, float*, int, int); + using mul_func_t = + void (*)(const float *, const float *, float *, int, int); - mul_func_t mul_func = (mul_func_t)mul.getCode(); + mul_func_t mul_func = (mul_func_t) mul.getCode(); - for (int ni = 0; ni < n; ni++) { - for (int ci = 0; ci < C; ci++) { - auto ptr_x = - x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + auto ptr_x = + x_data + ni * C * h * w * simd_width + + ci * h * w * simd_width; - auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; - auto ptr_z = - z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_z = + z_data + ni * C * h * w * simd_width + + ci * h * w * simd_width; - mul_func(ptr_x, ptr_y, ptr_z, h, w); + mul_func(ptr_x, ptr_y, ptr_z, h, w); + } } } + + z->set_layout(DataLayout::kMKLDNN); + z->set_format(x->format()); + } else { + PADDLE_THROW("Not implemented when dims are equal"); } + } else { + // Fallback to naive version: + auto mul_func = [](T a, T b) -> T { return a * b; }; + + TransformFunctor + functor( + x, y, z, + ctx.template device_context(), + mul_func); + axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed); + axis = (y_dims.size() == 0) ? 
x_dims.size() : axis; + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); + + if (post == 1) { + functor.RunRowWise(n, pre); + } else { + functor.RunMidWise(n, pre, post); + } z->set_layout(DataLayout::kMKLDNN); z->set_format(x->format()); - } else { - PADDLE_THROW("Not implemented when dims are equal"); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py index a0581d16de..a89f439664 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py @@ -20,8 +20,7 @@ import paddle.fluid.core as core from paddle.fluid.op import Operator from test_elementwise_mul_op import * - -class ElementwiseMulMKLDNNOp(ElementwiseMulOp): +class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): def init_input_output(self): x = np.random.rand(1, 16, 2, 2).astype(self.dtype) self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) @@ -30,6 +29,11 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp): self.out = x * self.y.reshape(1, 16, 1, 1) self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + def setUp(self): + super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nc" + def init_kernel_type(self): self.use_mkldnn = True @@ -45,6 +49,27 @@ class ElementwiseMulMKLDNNOp(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass +class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = self.x * self.y.reshape(1, 16, 1, 1) + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass if __name__ == '__main__': unittest.main() From d14858e4baf0aaeeaa9ccd33623958de6f4a6bd4 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 6 Nov 2018 12:52:44 +0100 Subject: [PATCH 07/15] MKLDNN elementwise_mul: Parallelize mul --- paddle/fluid/operators/elementwise_mul_mkldnn_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index d66c58bd45..36e88cd789 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -144,6 +144,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { mul_func_t mul_func = (mul_func_t) mul.getCode(); + #pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { auto ptr_x = From f820573b9c6ffee12aaf64b656d902dc0c9532f5 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Wed, 7 Nov 2018 11:37:27 +0100 Subject: [PATCH 08/15] MKLDNN elementwise_mul: Add UTs --- .../test_elementwise_mul_mkldnn_op.py | 119 +++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py index a89f439664..a008979801 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py 
@@ -49,7 +49,37 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass -class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp): +@unittest.skip("Not implemented yet.") +class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 8, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + self.y = np.random.rand(1, 8).astype(self.dtype) + + self.out = x * self.y.reshape(1, 8, 1, 1) + self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_BroadcastNCHW8c, self).setUp() + self.attrs["x_data_format"] = "nchw8c" + self.attrs["y_data_format"] = "nc" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + +class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): def init_input_output(self): self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) self.y = np.random.rand(1, 16).astype(self.dtype) @@ -71,5 +101,92 @@ class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass +class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + +class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + +@unittest.skip("Not implemented yet.") +class TestElementwiseMulMKLDNNOp_FallbackWithReorder(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() + self.attrs["x_data_format"] = "nchw" + self.attrs["y_data_format"] = "nchw16c" + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def 
test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + if __name__ == '__main__': unittest.main() From 49b09327f673598dfaeac4bcc2613d50228b2a73 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Fri, 9 Nov 2018 15:21:07 +0100 Subject: [PATCH 09/15] MKLDNN elementwise_mul: Reorder on non-nchw input, fallback on non-16 divisable fm test=develop --- .../operators/elementwise_mul_mkldnn_op.cc | 111 ++++++++++++------ .../test_elementwise_mul_mkldnn_op.py | 62 +++++++++- 2 files changed, 131 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 36e88cd789..58aadd0033 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -95,6 +95,26 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx, } } +template +static void ReorderInput(framework::Tensor* tensor, + const platform::Place& place, + const mkldnn::engine& engine, + bool isFourDim) { + using platform::to_void_cast; + auto dims = paddle::framework::vectorize2int(tensor->dims()); + framework::Tensor out_tensor; + out_tensor.Resize(tensor->dims()); + out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc); + out_tensor.set_layout(tensor->layout()); + mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType(), + tensor->format()}, engine}, to_void_cast(tensor->data())}; + mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType(), + out_tensor.format()}, engine}, + to_void_cast(out_tensor.mutable_data(place))}; + platform::Reorder(input_memory, output_memory); + tensor->ShareDataWith(out_tensor); +} + template class ElementwiseMulMKLDNNKernel : public framework::OpKernel { public: @@ -111,63 +131,78 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto x_dims = x->dims(); auto y_dims_untrimmed = y->dims(); + auto x_int_dims = paddle::framework::vectorize2int(x_dims); UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); - if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) { - if (x_dims != y_dims_untrimmed) { - int pre, n, post; - get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); + const bool are_dims_divisable = !(x_int_dims[1] % 16); + const bool is_x_format_correct = x->format() == memory::format::nChw16c; + const bool is_y_format_correct = y->format() == memory::format::nc; + if (is_x_format_correct && is_y_format_correct && are_dims_divisable) { + int pre, n, post; + get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); - if (post == 1) { - PADDLE_THROW("Not implemented when post is 1"); - } else { - // Just check whether it works for RE-Resnext. - PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); + if (post == 1) { + PADDLE_THROW("Not implemented when post is 1"); + } else { + // Just check whether it works for RE-Resnext. 
+ PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); - int n = x_dims[0]; - int c = x_dims[1]; - int h = x_dims[2]; - int w = x_dims[3]; + int n = x_dims[0]; + int c = x_dims[1]; + int h = x_dims[2]; + int w = x_dims[3]; - PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, - "Y should be in nc format"); + PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, + "Y should be in nc format"); - constexpr int simd_width = 16; - int C = c / simd_width; + constexpr int simd_width = 16; + int C = c / simd_width; - vector_mul mul; + vector_mul mul; - using mul_func_t = - void (*)(const float *, const float *, float *, int, int); + using mul_func_t = + void (*)(const float *, const float *, float *, int, int); - mul_func_t mul_func = (mul_func_t) mul.getCode(); + mul_func_t mul_func = (mul_func_t) mul.getCode(); - #pragma omp parallel for collapse(2) - for (int ni = 0; ni < n; ni++) { - for (int ci = 0; ci < C; ci++) { - auto ptr_x = - x_data + ni * C * h * w * simd_width + - ci * h * w * simd_width; + #pragma omp parallel for collapse(2) + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + auto ptr_x = + x_data + ni * C * h * w * simd_width + + ci * h * w * simd_width; - auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; - auto ptr_z = - z_data + ni * C * h * w * simd_width + - ci * h * w * simd_width; + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_z = + z_data + ni * C * h * w * simd_width + + ci * h * w * simd_width; - mul_func(ptr_x, ptr_y, ptr_z, h, w); - } + mul_func(ptr_x, ptr_y, ptr_z, h, w); } } - - z->set_layout(DataLayout::kMKLDNN); - z->set_format(x->format()); - } else { - PADDLE_THROW("Not implemented when dims are equal"); } + + z->set_layout(DataLayout::kMKLDNN); + z->set_format(x->format()); } else { // Fallback to naive version: + const bool are_inputs_in_same_format = x->format() == y->format(); + const bool is_x_nchw= x->format() == memory::format::nchw; + const bool is_x_nc = x->format() == memory::format::nc; + const bool is_y_nchw= y->format() == memory::format::nchw; + const bool is_y_nc = y->format() == memory::format::nc; + if(!are_inputs_in_same_format) { + using platform::MKLDNNDeviceContext; + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + if(!(is_x_nchw || is_x_nc)) + ReorderInput((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4); + if(!(is_y_nchw || is_y_nc)) + ReorderInput((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4); + } + auto mul_func = [](T a, T b) -> T { return a * b; }; TransformFunctor Date: Fri, 9 Nov 2018 15:43:55 +0100 Subject: [PATCH 10/15] Add Sand3r- to AUTHORS.md test=develop --- AUTHORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS.md b/AUTHORS.md index 4060f75613..54a1097b50 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -42,6 +42,7 @@ | QiJune | Jun Qi | | qingqing01 | Qing-Qing Dang | | reyoung | Yang Yu | +| Sand3r- | Michal Gallus | | Superjom | Chun-Wei Yan | | tensor-tang | Jian Tang | | tianbingsz | Tian-Bing Xu | From 08f63c4d1253007ee6290f8dfab3c31195940168 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 13 Nov 2018 09:12:10 +0100 Subject: [PATCH 11/15] MKLDNN elementwise_mul: Lint changes to UT & integration test=develop --- .../operators/elementwise/elementwise_op.h | 24 ++++----- .../operators/elementwise_mul_mkldnn_op.cc | 54 +++++++++---------- .../test_elementwise_mul_mkldnn_op.py | 12 ++++- 3 files changed, 50 insertions(+), 
40 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 16d919689c..85a7817be9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -98,19 +98,19 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("use_mkldnn", "(bool, default false). Used by MKLDNN.") .SetDefault(false); AddAttr( - "x_data_format", - "(string, default NCHW) Only used in mkldnn" - "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " - "Defaults to \"\". Specify the data format of the output data, " - "the input will be transformed automatically. ") - .SetDefault(""); + "x_data_format", + "(string, default NCHW) Only used in mkldnn" + "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " + "Defaults to \"\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault(""); AddAttr( - "y_data_format", - "(string, default \"\") Only used in mkldnn" - "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " - "Defaults to \"\". Specify the data format of the output data, " - "the input will be transformed automatically. ") - .SetDefault(""); + "y_data_format", + "(string, default \"\") Only used in mkldnn" + "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " + "Defaults to \"\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault(""); AddComment(string::Sprintf(R"DOC( Elementwise %s Operator diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 58aadd0033..6371c9f839 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -71,13 +71,13 @@ void check(const float* x, const float* y, float* z, int w) { static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { std::transform(format.begin(), format.end(), format.begin(), ::tolower); - if(!format.compare("nchw")) { + if (!format.compare("nchw")) { return memory::format::nchw; - } else if(!format.compare("nchw16c")) { + } else if (!format.compare("nchw16c")) { return memory::format::nChw16c; - } else if(!format.compare("nchw8c")) { + } else if (!format.compare("nchw8c")) { return memory::format::nChw8c; - } else if(!format.compare("nhwc")) { + } else if (!format.compare("nhwc")) { return memory::format::nhwc; } else { return memory::format::any; @@ -85,8 +85,8 @@ static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { } static void UpdateDataFormat(const framework::ExecutionContext& ctx, - framework::Tensor* tensor, const char* attribute) { - if(ctx.op().HasAttr(attribute)) { + framework::Tensor* tensor, const char* attribute) { + if (ctx.op().HasAttr(attribute)) { auto format_as_string = ctx.Attr(attribute); auto format = StringToMKLDNNFormat(format_as_string); if (format != memory::format::any) { @@ -98,19 +98,19 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx, template static void ReorderInput(framework::Tensor* tensor, const platform::Place& place, - const mkldnn::engine& engine, - bool isFourDim) { + const mkldnn::engine& engine, bool isFourDim) { using platform::to_void_cast; auto dims = paddle::framework::vectorize2int(tensor->dims()); framework::Tensor out_tensor; out_tensor.Resize(tensor->dims()); 
out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc); out_tensor.set_layout(tensor->layout()); - mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType(), - tensor->format()}, engine}, to_void_cast(tensor->data())}; - mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType(), - out_tensor.format()}, engine}, - to_void_cast(out_tensor.mutable_data(place))}; + mkldnn::memory input_memory = { + {{dims, platform::MKLDNNGetDataType(), tensor->format()}, engine}, + to_void_cast(tensor->data())}; + mkldnn::memory output_memory = { + {{dims, platform::MKLDNNGetDataType(), out_tensor.format()}, engine}, + to_void_cast(out_tensor.mutable_data(place))}; platform::Reorder(input_memory, output_memory); tensor->ShareDataWith(out_tensor); } @@ -163,21 +163,19 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { vector_mul mul; using mul_func_t = - void (*)(const float *, const float *, float *, int, int); + void (*)(const float*, const float*, float*, int, int); - mul_func_t mul_func = (mul_func_t) mul.getCode(); + mul_func_t mul_func = (mul_func_t)mul.getCode(); - #pragma omp parallel for collapse(2) +#pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { auto ptr_x = - x_data + ni * C * h * w * simd_width + - ci * h * w * simd_width; + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; auto ptr_z = - z_data + ni * C * h * w * simd_width + - ci * h * w * simd_width; + z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; mul_func(ptr_x, ptr_y, ptr_z, h, w); } @@ -189,18 +187,20 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { } else { // Fallback to naive version: const bool are_inputs_in_same_format = x->format() == y->format(); - const bool is_x_nchw= x->format() == memory::format::nchw; + const bool is_x_nchw = x->format() == memory::format::nchw; const bool is_x_nc = x->format() == memory::format::nc; - const bool is_y_nchw= y->format() == memory::format::nchw; + const bool is_y_nchw = y->format() == memory::format::nchw; const bool is_y_nc = y->format() == memory::format::nc; - if(!are_inputs_in_same_format) { + if (!are_inputs_in_same_format) { using platform::MKLDNNDeviceContext; auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); - if(!(is_x_nchw || is_x_nc)) - ReorderInput((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4); - if(!(is_y_nchw || is_y_nc)) - ReorderInput((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4); + if (!(is_x_nchw || is_x_nc)) + ReorderInput((Tensor*)x, ctx.GetPlace(), mkldnn_engine, + x->dims().size() == 4); + if (!(is_y_nchw || is_y_nc)) + ReorderInput((Tensor*)y, ctx.GetPlace(), mkldnn_engine, + y->dims().size() == 4); } auto mul_func = [](T a, T b) -> T { return a * b; }; diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py index 77d24a81f2..56e2ca849a 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py @@ -20,6 +20,7 @@ import paddle.fluid.core as core from paddle.fluid.op import Operator from test_elementwise_mul_op import * + class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): def init_input_output(self): x = np.random.rand(1, 16, 2, 2).astype(self.dtype) @@ 
-49,7 +50,9 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass -@unittest.skip("Not implemented yet.") # TODO(mgallus): enable when implemented. + +@unittest.skip( + "Not implemented yet.") # TODO(mgallus): enable when implemented. class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): def init_input_output(self): x = np.random.rand(1, 8, 2, 2).astype(self.dtype) @@ -79,6 +82,7 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass + class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): def init_input_output(self): self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) @@ -101,6 +105,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass + class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): def init_input_output(self): x = np.random.rand(1, 16, 2, 2).astype(self.dtype) @@ -130,6 +135,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass + class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): def init_input_output(self): x = np.random.rand(1, 16, 2, 2).astype(self.dtype) @@ -159,6 +165,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass + class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): def init_input_output(self): self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) @@ -187,6 +194,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass + class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): def init_input_output(self): self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype) @@ -215,6 +223,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass + class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): def init_input_output(self): self.x = np.random.rand(1, 16).astype(self.dtype) @@ -242,5 +251,6 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): def test_check_grad_ingore_y(self): pass + if __name__ == '__main__': unittest.main() From 785066eb8aa1ec552f3d093e8a7aa3d229700572 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 13 Nov 2018 12:12:08 +0100 Subject: [PATCH 12/15] MKLDNN elementwise_mul: Check if AVX512 is available test=develop --- paddle/fluid/operators/elementwise_mul_mkldnn_op.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc index 6371c9f839..216c7ed9c6 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc @@ -136,10 +136,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); + Xbyak::util::Cpu cpu; + const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F); const bool are_dims_divisable = !(x_int_dims[1] % 16); const bool is_x_format_correct = x->format() == memory::format::nChw16c; const bool is_y_format_correct = y->format() == memory::format::nc; - if (is_x_format_correct && is_y_format_correct && are_dims_divisable) { + if (is_x_format_correct && is_y_format_correct && 
are_dims_divisable && + is_avx512_enabled) { int pre, n, post; get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); From 99e3e36a5701bf15e9a18f01b19a60ced78137aa Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 13 Nov 2018 15:03:14 +0100 Subject: [PATCH 13/15] MKLDNN elementwise_mul: Disable UT for CUDA test=develop --- python/paddle/fluid/tests/unittests/op_test.py | 4 +++- .../tests/unittests/test_elementwise_mul_mkldnn_op.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 690c4cf0ad..c195a28e45 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -362,7 +362,9 @@ class OpTest(unittest.TestCase): else: return [] places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): + cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False + if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\ + and not cpu_only: places.append(core.CUDAPlace(0)) return places diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py index 56e2ca849a..536e9a1c58 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py @@ -34,6 +34,7 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp() self.attrs["x_data_format"] = "nchw16c" self.attrs["y_data_format"] = "nc" + self._cpu_only = True def init_kernel_type(self): self.use_mkldnn = True @@ -66,6 +67,7 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): super(TestElementwiseMulMKLDNNOp_BroadcastNCHW8c, self).setUp() self.attrs["x_data_format"] = "nchw8c" self.attrs["y_data_format"] = "nc" + self._cpu_only = True def init_kernel_type(self): self.use_mkldnn = True @@ -119,6 +121,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() self.attrs["x_data_format"] = "nchw16c" self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True def init_kernel_type(self): self.use_mkldnn = True @@ -149,6 +152,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp() self.attrs["x_data_format"] = "nchw16c" self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True def init_kernel_type(self): self.use_mkldnn = True @@ -178,6 +182,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp() self.attrs["x_data_format"] = "nchw" self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True def init_kernel_type(self): self.use_mkldnn = True @@ -207,6 +212,7 @@ class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp() self.attrs["x_data_format"] = "nchw16c" self.attrs["y_data_format"] = "nchw" + self._cpu_only = True def init_kernel_type(self): self.use_mkldnn = True @@ -235,6 +241,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp() self.attrs["x_data_format"] = "nc" self.attrs["y_data_format"] 
= "nc" + self._cpu_only = True def init_kernel_type(self): self.use_mkldnn = True From c69c41604e29dfc8b463cb79fc4cc1864ba15372 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Thu, 15 Nov 2018 15:14:48 +0100 Subject: [PATCH 14/15] MKLDNN elementwise_mul: Move Kernel to KernelPool to avoid segfaults test=develop --- .../elementwise_mul_mkldnn_op.cc | 61 +++---------------- paddle/fluid/operators/math/jit_code.h | 36 +++++++++++ paddle/fluid/operators/math/jit_kernel.h | 9 +++ .../fluid/operators/math/jit_kernel_blas.cc | 41 +++++++++++++ 4 files changed, 95 insertions(+), 52 deletions(-) rename paddle/fluid/operators/{ => elementwise}/elementwise_mul_mkldnn_op.cc (85%) diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc similarity index 85% rename from paddle/fluid/operators/elementwise_mul_mkldnn_op.cc rename to paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc index 216c7ed9c6..10290a4aef 100644 --- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/operators/elementwise_op.h" -#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/platform/mkldnn_helper.h" -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" +#include "paddle/fluid/operators/math/jit_kernel.h" +#include "xbyak.h" +#include "xbyak_util.h" namespace paddle { namespace operators { @@ -27,47 +28,6 @@ namespace operators { using framework::DataLayout; using mkldnn::memory; -struct vector_mul : public Xbyak::CodeGenerator { - vector_mul() { - // RDI is ptr X - // RSI is ptr Y - // RDX is ptr Z - // RCX is h - // r8 is w - - push(rbx); - - xor_(rax, rax); - xor_(r10, r10); - vmovups(zmm3, ptr[rsi]); - - L("h_loop"); - xor_(rbx, rbx); - L("w_loop"); - vmovups(zmm2, ptr[rdi + rax]); - vmulps(zmm1, zmm2, zmm3); - vmovups(ptr[rdx + rax], zmm1); - add(rax, 64); - inc(rbx); - cmp(r8, rbx); - jnz("w_loop"); - inc(r10); - cmp(r10, rcx); - jnz("h_loop"); - - pop(rbx); - ret(); - } -}; - -void check(const float* x, const float* y, float* z, int w) { - for (int wi = 0; wi < w; wi++) { - for (int i = 0; i < 16; i++) { - z[wi * 16 + i] = x[wi * 16 + i] * y[i]; - } - } -} - static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { std::transform(format.begin(), format.end(), format.begin(), ::tolower); @@ -163,12 +123,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { constexpr int simd_width = 16; int C = c / simd_width; - vector_mul mul; - - using mul_func_t = - void (*)(const float*, const float*, float*, int, int); - - mul_func_t mul_func = (mul_func_t)mul.getCode(); + const auto& multiply = + math::jitkernel::KernelPool::Instance() + .template Get>(n); #pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { @@ -180,7 +137,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto ptr_z = z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - mul_func(ptr_x, ptr_y, ptr_z, h, w); + multiply->Compute(ptr_x, ptr_y, ptr_z, h, w); } } } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 71205b211b..dbfe629013 100644 --- 
a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -156,6 +156,42 @@ class VActJitCode : public JitCode { ymm_t ymm_dst = ymm_t(1); }; +#ifdef PADDLE_WITH_MKLDNN +struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { + explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) + : Xbyak::CodeGenerator(code_size) { + // RDI is ptr x_input + // RSI is ptr y_input + // RDX is ptr output + // RCX is height + // r8 is width + + push(rbx); + + xor_(rax, rax); + xor_(r10, r10); + vmovups(zmm3, ptr[rsi]); + + L("h_loop"); + xor_(rbx, rbx); + L("w_loop"); + vmovups(zmm2, ptr[rdi + rax]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx + rax], zmm1); + add(rax, 64); + inc(rbx); + cmp(r8, rbx); + jnz("w_loop"); + inc(r10); + cmp(r10, rcx); + jnz("h_loop"); + + pop(rbx); + ret(); + } +}; +#endif + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 4d8d3cd79a..110de3b140 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -94,6 +94,15 @@ class VAddBiasKernel : public Kernel { void (*Compute)(const T *, const T *, T *, int); }; +#ifdef PADDLE_WITH_MKLDNN +template +class EltwiseMulnChw16cNCKernel : public Kernel { + public: + // nChw16c = nChw16c .* NC + void (*Compute)(const float *, const float *, float *, int, int); +}; +#endif + template class VActKernel : public Kernel { public: diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 36a50f2043..a143b51439 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -226,6 +226,44 @@ bool VAddKernelImpl::useMKL(int d) { } #endif +#ifdef PADDLE_WITH_MKLDNN +/* EltwiseMul for nChw16c & NC inputs JitKernel */ +template +class EltwiseMulnChw16cNCKernelImpl + : public math::jitkernel::EltwiseMulnChw16cNCKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit EltwiseMulnChw16cNCKernelImpl(int d) + : EltwiseMulnChw16cNCKernel() { + using mul_func_t = void (*)(const float*, const float*, float*, int, int); +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + // roughly estimate the size of code + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; + sz = sz > 4096 ? 
sz : 4096; + jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz)); + this->Compute = (mul_func_t)jitcode_->getCode(); + return; + } +#endif + PADDLE_THROW( + "This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN " + "environemnt"); + } + +#ifdef PADDLE_WITH_XBYAK + + private: + std::unique_ptr jitcode_{nullptr}; +}; + +template <> +bool EltwiseMulnChw16cNCKernelImpl::useJIT(int d) { + return true; +} +#endif +#endif + /* VAddRelu JitKernel */ template class VAddReluKernelImpl : public VAddReluKernel { @@ -394,6 +432,9 @@ REGISTER_JITKERNEL(vscal, VScalKernel); REGISTER_JITKERNEL(vaddbias, VAddBiasKernel); REGISTER_JITKERNEL(vrelu, VReluKernel); REGISTER_JITKERNEL(videntity, VIdentityKernel); +#ifdef PADDLE_WITH_MKLDNN +REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel); +#endif } // namespace jitkernel } // namespace math From def272cf42e9b2ebf529b39f183874a1dede9c2a Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Fri, 16 Nov 2018 15:29:15 +0100 Subject: [PATCH 15/15] MKLDNN elementwise_mul: Revert changes to eltwise_add tests --- .../paddle/fluid/tests/unittests/test_elementwise_add_op.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index d71a9c0151..5aec5d8e38 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -43,13 +43,19 @@ class TestElementwiseAddOp(OpTest): self.check_output() def test_check_grad_normal(self): + if self.dtype == np.float16: + return self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) def test_check_grad_ingore_x(self): + if self.dtype == np.float16: + return self.check_grad( ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) def test_check_grad_ingore_y(self): + if self.dtype == np.float16: + return self.check_grad( ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))