|
|
|
@ -17,7 +17,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_layout_transform.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_reuse.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -65,21 +65,27 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
(src_x_tz.size() == 5 &&
|
|
|
|
|
x->format() != (format = memory::format::ncdhw))) {
|
|
|
|
|
_x.Resize(x_dims);
|
|
|
|
|
auto user_x_memory_pd = memory::primitive_desc(
|
|
|
|
|
{{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
|
|
|
|
|
auto x_memory_pd = memory::primitive_desc(
|
|
|
|
|
{{src_x_tz}, memory::data_type::f32, format}, mkldnn_engine);
|
|
|
|
|
auto size = x_memory_pd.get_size();
|
|
|
|
|
_x.mutable_data<T>(ctx.GetPlace(), size);
|
|
|
|
|
auto user_x_memory =
|
|
|
|
|
memory(user_x_memory_pd, paddle::platform::to_void_cast<T>(x_data));
|
|
|
|
|
auto x_memory = memory(x_memory_pd,
|
|
|
|
|
paddle::platform::to_void_cast<T>(_x.data<T>()));
|
|
|
|
|
|
|
|
|
|
auto x_reorder = reorder(user_x_memory, x_memory);
|
|
|
|
|
|
|
|
|
|
mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>();
|
|
|
|
|
auto out_format = platform::MKLDNNFormatForSize(
|
|
|
|
|
x_dims.size(), mkldnn::memory::format::nchw);
|
|
|
|
|
|
|
|
|
|
const std::string key = platform::ReorderMKLDNNHandler::GetHash(
|
|
|
|
|
src_x_tz, x->format(), out_format, std::to_string(in_type));
|
|
|
|
|
|
|
|
|
|
platform::ReorderMKLDNNHandler handler(src_x_tz, x->type(), in_type,
|
|
|
|
|
dev_ctx, mkldnn_engine, key);
|
|
|
|
|
|
|
|
|
|
auto user_x_memory_p = handler.AcquireSrcMemory(
|
|
|
|
|
x->format(), paddle::platform::to_void_cast(x_data));
|
|
|
|
|
|
|
|
|
|
auto x_memory_p =
|
|
|
|
|
handler.AcquireDstMemory(&_x, out_format, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto x_reorder = handler.AcquireReorder(x_memory_p, user_x_memory_p);
|
|
|
|
|
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
|
pipeline.push_back(x_reorder);
|
|
|
|
|
pipeline.push_back(*x_reorder);
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
} else {
|
|
|
|
|
format = x->format();
|
|
|
|
@ -125,46 +131,41 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> dst_tz = framework::vectorize2int(z_dims);
|
|
|
|
|
|
|
|
|
|
std::vector<memory::primitive_desc> srcs_pd;
|
|
|
|
|
std::vector<memory> srcs;
|
|
|
|
|
std::vector<float> scales = {1.0f, 1.0f};
|
|
|
|
|
|
|
|
|
|
auto src_x_pd = memory::primitive_desc(
|
|
|
|
|
{{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
|
|
|
|
|
auto src_y_pd = memory::primitive_desc(
|
|
|
|
|
{{src_y_tz}, memory::data_type::f32, y->format()}, mkldnn_engine);
|
|
|
|
|
auto src_x_memory =
|
|
|
|
|
memory(src_x_pd, paddle::platform::to_void_cast(x_data));
|
|
|
|
|
auto src_y_memory =
|
|
|
|
|
memory(src_y_pd, paddle::platform::to_void_cast(y_data));
|
|
|
|
|
const std::string key = platform::MKLDNNHandler::GetHash(
|
|
|
|
|
src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) +
|
|
|
|
|
std::to_string(y->format()));
|
|
|
|
|
|
|
|
|
|
platform::SumMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
|
|
|
|
|
|
|
|
|
|
auto src_x_memory = handler.AcquireSrcMemory(
|
|
|
|
|
{{src_x_tz}, platform::MKLDNNGetDataType<T>(), x->format()},
|
|
|
|
|
paddle::platform::to_void_cast(x_data));
|
|
|
|
|
|
|
|
|
|
srcs_pd.push_back(src_x_pd);
|
|
|
|
|
srcs_pd.push_back(src_y_pd);
|
|
|
|
|
srcs.push_back(src_x_memory);
|
|
|
|
|
srcs.push_back(src_y_memory);
|
|
|
|
|
auto src_y_memory = handler.AcquireSecondSrcMemory(
|
|
|
|
|
{{src_y_tz}, platform::MKLDNNGetDataType<T>(), y->format()},
|
|
|
|
|
paddle::platform::to_void_cast(y_data));
|
|
|
|
|
|
|
|
|
|
auto dst_md =
|
|
|
|
|
memory::desc({dst_tz}, memory::data_type::f32, memory::format::any);
|
|
|
|
|
auto dst_md = memory::desc({dst_tz}, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
memory::format::any);
|
|
|
|
|
|
|
|
|
|
// create primitive descriptor for sum
|
|
|
|
|
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
|
|
|
|
|
auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
|
|
|
|
|
{src_x_memory, src_y_memory}, scales, dst_md);
|
|
|
|
|
|
|
|
|
|
// create mkldnn memory for dst
|
|
|
|
|
memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
|
|
|
|
|
auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data);
|
|
|
|
|
|
|
|
|
|
std::vector<primitive::at> inputs;
|
|
|
|
|
inputs.push_back(srcs[0]);
|
|
|
|
|
inputs.push_back(srcs[1]);
|
|
|
|
|
std::vector<primitive::at> inputs({*src_x_memory, *src_y_memory});
|
|
|
|
|
|
|
|
|
|
// create sum primitive
|
|
|
|
|
auto sum_prim = sum(sum_pd, inputs, dst_memory);
|
|
|
|
|
auto sum_prim = handler.AcquireSum(dst_memory, &inputs);
|
|
|
|
|
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
|
pipeline.push_back(sum_prim);
|
|
|
|
|
pipeline.push_back(*sum_prim);
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
z->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
z->set_format(
|
|
|
|
|
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
|
|
|
|
|
(memory::format)dst_memory->get_primitive_desc().desc().data.format);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|