|
|
|
@ -37,7 +37,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
"It must use CPUPlace.");
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<platform::MKLDNNDeviceContext>();
|
|
|
|
|
const auto& mkldnn_engine = dev_ctx.GetEngine();
|
|
|
|
|
|
|
|
|
|
const Tensor* input = ctx.Input<Tensor>("X");
|
|
|
|
|
Tensor* output = ctx.Output<Tensor>("Out");
|
|
|
|
@ -66,52 +65,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE(input->dims().size() == 4,
|
|
|
|
|
"Input dim must be with 4, i.e. NCHW");
|
|
|
|
|
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto src_tz = paddle::framework::vectorize<int>(input->dims());
|
|
|
|
|
auto dst_tz = paddle::framework::vectorize<int>(output->dims());
|
|
|
|
|
|
|
|
|
|
auto input_format = input->format();
|
|
|
|
|
MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef};
|
|
|
|
|
|
|
|
|
|
mkldnn::memory::data_type dt =
|
|
|
|
|
paddle::framework::ToMKLDNNDataType(input->type());
|
|
|
|
|
auto fmt = input->format();
|
|
|
|
|
|
|
|
|
|
const std::string key =
|
|
|
|
|
platform::CreateKey(src_tz, pooling_type, ksize, strides, paddings, dt,
|
|
|
|
|
fmt, ctx.op().Output("Out"));
|
|
|
|
|
|
|
|
|
|
platform::PoolingMKLDNNHandler handler(pooling_type, dt,
|
|
|
|
|
ctx.Attr<bool>("is_test"), dev_ctx,
|
|
|
|
|
mkldnn_engine, key);
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
|
|
|
|
|
|
|
|
|
|
auto src_memory =
|
|
|
|
|
handler.AcquireSrcMemory(src_md, to_void_cast<T>(input_data));
|
|
|
|
|
|
|
|
|
|
/* create memory descriptor for pooling without specified format
|
|
|
|
|
* ('any') which lets a primitive (pooling in this case) choose
|
|
|
|
|
* the memory format preferred for best performance
|
|
|
|
|
*/
|
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);
|
|
|
|
|
|
|
|
|
|
auto pooling_pd = handler.AcquirePoolingPrimitiveDescriptor(
|
|
|
|
|
src_tz, dst_tz, src_md, dst_md, ksize, strides, paddings,
|
|
|
|
|
ctx.Attr<bool>("ceil_mode"));
|
|
|
|
|
|
|
|
|
|
auto dst_memory =
|
|
|
|
|
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
|
|
|
|
|
|
|
|
|
|
auto pool_p = handler.AcquirePooling(dst_memory, src_memory);
|
|
|
|
|
auto is_test = ctx.Attr<bool>("is_test");
|
|
|
|
|
|
|
|
|
|
platform::PoolingMKLDNNHandler<T> handler(
|
|
|
|
|
src_tz, dst_tz, ksize, strides, paddings, pooling_type,
|
|
|
|
|
ctx.Attr<bool>("ceil_mode"), input->format(),
|
|
|
|
|
paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx,
|
|
|
|
|
ctx.GetPlace(), ctx.op().Output("Out"));
|
|
|
|
|
|
|
|
|
|
auto src_memory = handler.AcquireSrcMemory(input);
|
|
|
|
|
auto dst_memory = handler.AcquireDstMemory(output);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::pooling_forward> pool_p;
|
|
|
|
|
std::shared_ptr<mkldnn::memory> workspace_memory;
|
|
|
|
|
if ((is_test == false) && (pooling_type == "max")) {
|
|
|
|
|
// Training
|
|
|
|
|
workspace_memory = handler.AcquireWorkspaceMemory();
|
|
|
|
|
pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory,
|
|
|
|
|
*workspace_memory);
|
|
|
|
|
} else {
|
|
|
|
|
// Inference
|
|
|
|
|
pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// push primitive to stream and wait until it's executed
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline{*pool_p};
|
|
|
|
|
stream(stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
output_format =
|
|
|
|
|
auto output_format =
|
|
|
|
|
(MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format;
|
|
|
|
|
|
|
|
|
|
output->set_layout(DataLayout::kMKLDNN);
|
|
|
|
@ -158,14 +142,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<platform::MKLDNNDeviceContext>();
|
|
|
|
|
const mkldnn::engine& mkldnn_engine = dev_ctx.GetEngine();
|
|
|
|
|
|
|
|
|
|
std::vector<mkldnn::primitive> pipeline;
|
|
|
|
|
|
|
|
|
|
const T* out_grad_data = out_grad->data<T>();
|
|
|
|
|
T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
MKLDNNMemoryFormat in_x_grad_format{MKLDNNMemoryFormat::format_undef};
|
|
|
|
|
|
|
|
|
|
auto diff_src_tz = paddle::framework::vectorize<int>(in_x_grad->dims());
|
|
|
|
|
auto diff_dst_tz = paddle::framework::vectorize<int>(out_grad->dims());
|
|
|
|
|
|
|
|
|
@ -175,34 +154,33 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
diff_src_tz, pooling_type, ksize, strides, paddings,
|
|
|
|
|
memory::data_type::f32, in_x->format(), ctx.op().Input("Out"));
|
|
|
|
|
|
|
|
|
|
platform::PoolingMKLDNNHandler handler(
|
|
|
|
|
pooling_type, paddle::framework::ToMKLDNNDataType(in_x_grad->type()),
|
|
|
|
|
false, dev_ctx, mkldnn_engine, key);
|
|
|
|
|
|
|
|
|
|
auto workspace = handler.AcquireWorkspaceMemory();
|
|
|
|
|
|
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{diff_dst_tz}, platform::MKLDNNGetDataType<T>(), out_grad->format());
|
|
|
|
|
|
|
|
|
|
auto diff_dst_memory = handler.AcquireDiffDstMemory(
|
|
|
|
|
diff_dst_md, to_void_cast<T>(out_grad_data));
|
|
|
|
|
|
|
|
|
|
auto diff_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
|
|
|
|
|
|
|
|
|
|
auto bwd_pd = handler.AcquirePoolingBackwardPrimitiveDescriptor(
|
|
|
|
|
diff_dst_md, diff_src_md, ksize, strides, paddings);
|
|
|
|
|
|
|
|
|
|
auto diff_src_memory = handler.AcquireDiffSrcMemoryFromPrimitive(
|
|
|
|
|
reinterpret_cast<void*>(in_x_grad_data));
|
|
|
|
|
|
|
|
|
|
auto pool_bwd_p = handler.AcquirePoolingBackward(diff_dst_memory, workspace,
|
|
|
|
|
diff_src_memory);
|
|
|
|
|
platform::PoolingMKLDNNHandler<T> handler(
|
|
|
|
|
diff_dst_tz, diff_src_tz, ksize, strides, paddings, pooling_type,
|
|
|
|
|
ctx.Attr<bool>("ceil_mode"), in_x->format(), out_grad->format(),
|
|
|
|
|
paddle::framework::ToMKLDNNDataType(out_grad->type()), dev_ctx,
|
|
|
|
|
ctx.GetPlace(), ctx.op().Input("Out"));
|
|
|
|
|
|
|
|
|
|
auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
|
|
|
|
|
auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::pooling_backward> pool_bwd_p;
|
|
|
|
|
std::shared_ptr<mkldnn::memory> workspace_memory;
|
|
|
|
|
if (pooling_type == "max") {
|
|
|
|
|
// Max - pooling needs Workspace
|
|
|
|
|
workspace_memory = handler.AcquireWorkspaceMemory();
|
|
|
|
|
pool_bwd_p = handler.AcquireBackwardPrimitive(
|
|
|
|
|
*diff_dst_memory, *workspace_memory, *diff_src_memory);
|
|
|
|
|
} else {
|
|
|
|
|
// Average Pooling
|
|
|
|
|
pool_bwd_p =
|
|
|
|
|
handler.AcquireBackwardPrimitive(*diff_dst_memory, *diff_src_memory);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pipeline.push_back(*pool_bwd_p);
|
|
|
|
|
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
|
|
|
|
|
|
|
|
|
in_x_grad_format = (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
|
|
|
|
|
auto in_x_grad_format =
|
|
|
|
|
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
|
|
|
|
|
.desc()
|
|
|
|
|
.data.format;
|
|
|
|
|
in_x_grad->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|