MKLDNN elementwise_add with default broadcast operations (#11544)
* elementwise_add with bcast: Brian's implementation added, with default bcasts
* elementwise_add with bcast: GetExpectedKernelType added to elementwise_op
* elementwise_add with bcast: use_mkldnn attribute added
* elementwise_add with bcast: changes after review and some formatting
* elementwise_add with bcast: changes after style check
* elementwise_add with bcast: changes after style check cont.
* elementwise_add with bcast: MKLDNN unittests added
* elementwise_add with bcast: original unittests with use_mkldnn flag
* elementwise_add with bcast: handling of MKLDNN format corrected
* elementwise_add with bcast: setting MKLDNN format turned into lambda
* elementwise_add with bcast: MKLDNN format setting turned into separate function
* elementwise_add with bcast: condition for choosing MKLDNN simplified
* elementwise_add with bcast: fix for MKLDNN format set incorrectly in bcasts
* elementwise_add with bcast: changes in unittests for broadcasts
* elementwise_add with bcast: fixes in unittests regarding dimensions
* elementwise_add with bcast: bring back correct format setting in mklml grad path
* elementwise_add with bcast: fixed compilation error
parent
67ab324090
commit
e26f51ce74
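In short, the kernel below handles the equal-shape case with the MKLDNN sum primitive and falls back to the default broadcast path otherwise; the new use_mkldnn attribute routes dispatch to it. A minimal sketch (not part of this commit, assuming only the standard OpTest fields that the unit tests further down also use) of how the attribute is exercised:

import numpy as np
from op_test import OpTest


class EltwiseAddMKLDNNSketch(OpTest):
    def setUp(self):
        self.op_type = "elementwise_add"
        x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32")
        y = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32")
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': x + y}
        # The attribute added by this commit; equal shapes take the
        # MKLDNN sum path, unequal shapes fall back to broadcasting.
        self.attrs = {'axis': -1, 'use_mkldnn': True}

    def test_check_output(self):
        self.check_output()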
@@ -0,0 +1,190 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"

#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::reorder;
using mkldnn::primitive;
using mkldnn::stream;
using mkldnn::sum;

template <typename T>
class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* z = ctx.Output<Tensor>("Out");
    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    T* z_data = z->mutable_data<T>(ctx.GetPlace());

    int axis = ctx.Attr<int>("axis");

    auto x_dims = x->dims();
    auto y_dims = y->dims();
    auto z_dims = z->dims();

    // Execute the default elementwise_add operator when a broadcast
    // needs to be performed.
    if (x_dims != y_dims) {
      auto sum_func = [](T a, T b) -> T { return a + b; };

      TransformFunctor<decltype(sum_func), T,
                       paddle::platform::CPUDeviceContext, T>
          functor(
              x, y, z,
              ctx.template device_context<paddle::platform::CPUDeviceContext>(),
              sum_func);

      axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                     "Axis should be in range [0, x_dims)");

      trim_trailing_singular_dims(&y_dims);
      axis = (y_dims.size() == 0) ? x_dims.size() : axis;

      int pre, n, post;
      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

      if (post == 1) {
        functor.RunRowWise(n, pre);
      } else {
        functor.RunMidWise(n, pre, post);
      }
      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    } else {
      PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
                         x->format() != memory::format::format_undef,
                     "Wrong layout/format set for X tensor");
      PADDLE_ENFORCE(y->layout() == DataLayout::kMKLDNN &&
                         y->format() != memory::format::format_undef,
                     "Wrong layout/format set for Y tensor");

      std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
      std::vector<int> src_y_tz = framework::vectorize2int(y_dims);
      std::vector<int> dst_tz = framework::vectorize2int(z_dims);

      std::vector<memory::primitive_desc> srcs_pd;
      std::vector<memory> srcs;
      std::vector<float> scales = {1.0f, 1.0f};

      auto src_x_pd = memory::primitive_desc(
          {{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
      auto src_y_pd = memory::primitive_desc(
          {{src_y_tz}, memory::data_type::f32, y->format()}, mkldnn_engine);
      auto src_x_memory =
          memory(src_x_pd, paddle::platform::to_void_cast(x_data));
      auto src_y_memory =
          memory(src_y_pd, paddle::platform::to_void_cast(y_data));

      srcs_pd.push_back(src_x_pd);
      srcs_pd.push_back(src_y_pd);
      srcs.push_back(src_x_memory);
      srcs.push_back(src_y_memory);

      auto dst_md =
          memory::desc({dst_tz}, memory::data_type::f32, memory::format::any);

      // create primitive descriptor for sum
      auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);

      // create mkldnn memory for dst
      memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);

      std::vector<primitive::at> inputs;
      inputs.push_back(srcs[0]);
      inputs.push_back(srcs[1]);

      // create sum primitive
      auto sum_prim = sum(sum_pd, inputs, dst_memory);

      std::vector<primitive> pipeline;
      pipeline.push_back(sum_prim);
      stream(stream::kind::eager).submit(pipeline).wait();

      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(
          (memory::format)dst_memory.get_primitive_desc().desc().data.format);
    }
  }
};

template <typename T>
class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* out = ctx.Input<Tensor>("Out");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");

    auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
      in->set_layout(DataLayout::kMKLDNN);
      in->set_format(out->format());
    };

    if (x->dims() == y->dims()) {
      auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
      if (dx) {
        blas.VCOPY(dout->numel(), dout->data<T>(),
                   dx->mutable_data<T>(ctx.GetPlace()));
        set_mkldnn_format(dx, dout);
      }

      if (dy) {
        blas.VCOPY(dout->numel(), dout->data<T>(),
                   dy->mutable_data<T>(ctx.GetPlace()));
        set_mkldnn_format(dy, dout);
      }
    } else {
      // Execute the default kernel when a broadcast is needed.
      ElemwiseGradCompute<paddle::platform::CPUDeviceContext, T,
                          IdentityGrad<T>, IdentityGrad<T>>(
          ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
          IdentityGrad<T>());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(elementwise_add, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::EltwiseAddMKLDNNKernel<float>)

REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::EltwiseAddMKLDNNGradKernel<float>)
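In the fallback branch above, get_mid_dims collapses x into a (pre, n, post) view around the span that y occupies starting at `axis`, which is what lets RunRowWise/RunMidWise avoid explicit broadcasting. A hedged numpy illustration of the assumed semantics — mid_dims_sketch is a stand-in for exposition, not Paddle's actual helper:

import numpy as np


def mid_dims_sketch(x_dims, y_dims, axis):
    """Assumed (pre, n, post) semantics of get_mid_dims."""
    pre = int(np.prod(x_dims[:axis]))                  # dims before the shared span
    n = int(np.prod(y_dims))                           # the span y covers
    post = int(np.prod(x_dims[axis + len(y_dims):]))   # dims after the span
    return pre, n, post


# Mirrors the rowwise_add_0 test below: x (2, 3, 4, 5), y (3, 4), axis=1.
# post != 1 here, so the kernel would call functor.RunMidWise(n, pre, post).
print(mid_dims_sketch([2, 3, 4, 5], [3, 4], 1))  # (2, 12, 5)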
@@ -0,0 +1,130 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from test_elementwise_add_op import *
'''
Some tests differ from those defined in test_elementwise_add_op.py
because MKLDNN does not support 3-dimensional tensors; such shapes
cause exceptions in the MKLDNN reorder primitive.
'''


class TestMKLDNNElementwiseAddOp(TestElementwiseAddOp):
    def init_input_output(self):
        self.x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype)
        self.y = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype)
        self.out = np.add(self.x, self.y)

    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_scalar(TestElementwiseAddOp_scalar):
    def init_input_output(self):
        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.y = np.random.rand(1).astype(self.dtype)
        self.out = self.x + self.y

    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_scalar2(TestElementwiseAddOp_scalar2):
    def init_input_output(self):
        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.y = np.random.rand(1, 1).astype(self.dtype)
        self.out = self.x + self.y

    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_Vector(TestElementwiseAddOp_Vector):
    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_broadcast_0(TestElementwiseAddOp_broadcast_0):
    def init_input_output(self):
        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.y = np.random.rand(2).astype(self.dtype)
        self.out = self.x + self.y.reshape(2, 1, 1, 1)

    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_broadcast_1(TestElementwiseAddOp_broadcast_1):
    def init_input_output(self):
        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.y = np.random.rand(3).astype(self.dtype)
        self.out = self.x + self.y.reshape(1, 3, 1, 1)

    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_broadcast_2(TestElementwiseAddOp_broadcast_2):
    def init_input_output(self):
        self.x = np.random.rand(2, 2, 3, 4).astype(self.dtype)
        self.y = np.random.rand(4).astype(self.dtype)
        self.out = self.x + self.y.reshape(1, 1, 1, 4)

    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_broadcast_3(TestElementwiseAddOp_broadcast_3):
    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_broadcast_4(TestElementwiseAddOp_broadcast_4):
    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_rowwise_add_0(
        TestElementwiseAddOp_rowwise_add_0):
    def init_input_output(self):
        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.y = np.random.rand(3, 4).astype(self.dtype)
        self.out = self.x + self.y.reshape(1, 3, 4, 1)

    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_rowwise_add_1(
        TestElementwiseAddOp_rowwise_add_1):
    def init_kernel_type(self):
        self.use_mkldnn = True


class TestMKLDNNElementwiseAddOp_channelwise_add(
        TestElementwiseAddOp_channelwise_add):
    def init_input_output(self):
        self.x = np.random.rand(3, 5, 20, 20).astype(self.dtype)
        self.y = np.random.rand(3, 1, 1, 1).astype(self.dtype)
        self.out = self.x + self.y

    def init_kernel_type(self):
        self.use_mkldnn = True


if __name__ == '__main__':
    unittest.main()
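For reference, a single case from this suite can be run in isolation with the standard unittest loader (plain unittest API, no Paddle-specific flags assumed):

import unittest

# Load and run one MKLDNN test case by its dotted name.
suite = unittest.defaultTestLoader.loadTestsFromName(
    'test_elementwise_add_mkldnn_op.TestMKLDNNElementwiseAddOp')
unittest.TextTestRunner(verbosity=2).run(suite)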