|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_layout_transform.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -53,12 +54,45 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Execute default elementwise_add operator when
|
|
|
|
|
// broadcast operations need to performed.
|
|
|
|
|
if (x_dims != y_dims_untrimed) {
|
|
|
|
|
Tensor _x;
|
|
|
|
|
mkldnn::memory::format format;
|
|
|
|
|
std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
|
|
|
|
|
|
|
|
|
|
if ((src_x_tz.size() == 3 &&
|
|
|
|
|
x->format() != (format = memory::format::ncw)) ||
|
|
|
|
|
(src_x_tz.size() == 4 &&
|
|
|
|
|
x->format() != (format = memory::format::nchw)) ||
|
|
|
|
|
(src_x_tz.size() == 5 &&
|
|
|
|
|
x->format() != (format = memory::format::ncdhw))) {
|
|
|
|
|
_x.Resize(x_dims);
|
|
|
|
|
auto user_x_memory_pd = memory::primitive_desc(
|
|
|
|
|
{{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
|
|
|
|
|
auto x_memory_pd = memory::primitive_desc(
|
|
|
|
|
{{src_x_tz}, memory::data_type::f32, format}, mkldnn_engine);
|
|
|
|
|
auto size = x_memory_pd.get_size();
|
|
|
|
|
_x.mutable_data<T>(ctx.GetPlace(), paddle::memory::Allocator::kDefault,
|
|
|
|
|
size);
|
|
|
|
|
auto user_x_memory =
|
|
|
|
|
memory(user_x_memory_pd, paddle::platform::to_void_cast<T>(x_data));
|
|
|
|
|
auto x_memory = memory(x_memory_pd,
|
|
|
|
|
paddle::platform::to_void_cast<T>(_x.data<T>()));
|
|
|
|
|
|
|
|
|
|
auto x_reorder = reorder(user_x_memory, x_memory);
|
|
|
|
|
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
|
pipeline.push_back(x_reorder);
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
} else {
|
|
|
|
|
format = x->format();
|
|
|
|
|
_x.ShareDataWith(*x);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto sum_func = [](T a, T b) -> T { return a + b; };
|
|
|
|
|
|
|
|
|
|
TransformFunctor<decltype(sum_func), T,
|
|
|
|
|
paddle::platform::CPUDeviceContext, T>
|
|
|
|
|
functor(
|
|
|
|
|
x, y, z,
|
|
|
|
|
&_x, y, z,
|
|
|
|
|
ctx.template device_context<paddle::platform::CPUDeviceContext>(),
|
|
|
|
|
sum_func);
|
|
|
|
|
|
|
|
|
@ -78,7 +112,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
functor.RunMidWise(n, pre, post);
|
|
|
|
|
}
|
|
|
|
|
z->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
z->set_format(x->format());
|
|
|
|
|
z->set_format(format);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
|
|
|
|
|
x->format() != memory::format::format_undef,
|
|
|
|
|