- Reimplemented pooling fwd mkldnn (#19911)

- First implementation of BWD and FWD of pooling mkl-dnn

- Compilation fix

- Fix

- Fix

- Fix

- Fix to crash

- Compilation fix

- Combined AcquireBackward with Fwd

test=develop
expand_as_op_1
Jacek Czaja 5 years ago committed by Tao Luo
parent 790d5226b5
commit 5b07ca9cdd

@ -37,7 +37,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"It must use CPUPlace.");
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
@ -66,52 +65,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(input->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW");
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto src_tz = paddle::framework::vectorize<int>(input->dims());
auto dst_tz = paddle::framework::vectorize<int>(output->dims());
auto input_format = input->format();
MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef};
mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type());
auto fmt = input->format();
const std::string key =
platform::CreateKey(src_tz, pooling_type, ksize, strides, paddings, dt,
fmt, ctx.op().Output("Out"));
platform::PoolingMKLDNNHandler handler(pooling_type, dt,
ctx.Attr<bool>("is_test"), dev_ctx,
mkldnn_engine, key);
auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format);
auto src_memory =
handler.AcquireSrcMemory(src_md, to_void_cast<T>(input_data));
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
*/
auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);
auto pooling_pd = handler.AcquirePoolingPrimitiveDescriptor(
src_tz, dst_tz, src_md, dst_md, ksize, strides, paddings,
ctx.Attr<bool>("ceil_mode"));
auto dst_memory =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
auto pool_p = handler.AcquirePooling(dst_memory, src_memory);
auto is_test = ctx.Attr<bool>("is_test");
platform::PoolingMKLDNNHandler<T> handler(
src_tz, dst_tz, ksize, strides, paddings, pooling_type,
ctx.Attr<bool>("ceil_mode"), input->format(),
paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx,
ctx.GetPlace(), ctx.op().Output("Out"));
auto src_memory = handler.AcquireSrcMemory(input);
auto dst_memory = handler.AcquireDstMemory(output);
std::shared_ptr<mkldnn::pooling_forward> pool_p;
std::shared_ptr<mkldnn::memory> workspace_memory;
if ((is_test == false) && (pooling_type == "max")) {
// Training
workspace_memory = handler.AcquireWorkspaceMemory();
pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory,
*workspace_memory);
} else {
// Inference
pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory);
}
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{*pool_p};
stream(stream::kind::eager).submit(pipeline).wait();
output_format =
auto output_format =
(MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format;
output->set_layout(DataLayout::kMKLDNN);
@ -158,14 +142,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const mkldnn::engine& mkldnn_engine = dev_ctx.GetEngine();
std::vector<mkldnn::primitive> pipeline;
const T* out_grad_data = out_grad->data<T>();
T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
MKLDNNMemoryFormat in_x_grad_format{MKLDNNMemoryFormat::format_undef};
auto diff_src_tz = paddle::framework::vectorize<int>(in_x_grad->dims());
auto diff_dst_tz = paddle::framework::vectorize<int>(out_grad->dims());
@ -175,36 +154,35 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, in_x->format(), ctx.op().Input("Out"));
platform::PoolingMKLDNNHandler handler(
pooling_type, paddle::framework::ToMKLDNNDataType(in_x_grad->type()),
false, dev_ctx, mkldnn_engine, key);
auto workspace = handler.AcquireWorkspaceMemory();
auto diff_dst_md = platform::MKLDNNMemDesc(
{diff_dst_tz}, platform::MKLDNNGetDataType<T>(), out_grad->format());
auto diff_dst_memory = handler.AcquireDiffDstMemory(
diff_dst_md, to_void_cast<T>(out_grad_data));
auto diff_src_md = platform::MKLDNNMemDesc(
diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
auto bwd_pd = handler.AcquirePoolingBackwardPrimitiveDescriptor(
diff_dst_md, diff_src_md, ksize, strides, paddings);
auto diff_src_memory = handler.AcquireDiffSrcMemoryFromPrimitive(
reinterpret_cast<void*>(in_x_grad_data));
auto pool_bwd_p = handler.AcquirePoolingBackward(diff_dst_memory, workspace,
diff_src_memory);
platform::PoolingMKLDNNHandler<T> handler(
diff_dst_tz, diff_src_tz, ksize, strides, paddings, pooling_type,
ctx.Attr<bool>("ceil_mode"), in_x->format(), out_grad->format(),
paddle::framework::ToMKLDNNDataType(out_grad->type()), dev_ctx,
ctx.GetPlace(), ctx.op().Input("Out"));
auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);
std::shared_ptr<mkldnn::pooling_backward> pool_bwd_p;
std::shared_ptr<mkldnn::memory> workspace_memory;
if (pooling_type == "max") {
// Max - pooling needs Workspace
workspace_memory = handler.AcquireWorkspaceMemory();
pool_bwd_p = handler.AcquireBackwardPrimitive(
*diff_dst_memory, *workspace_memory, *diff_src_memory);
} else {
// Average Pooling
pool_bwd_p =
handler.AcquireBackwardPrimitive(*diff_dst_memory, *diff_src_memory);
}
pipeline.push_back(*pool_bwd_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
in_x_grad_format = (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
auto in_x_grad_format =
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
in_x_grad->set_layout(DataLayout::kMKLDNN);
in_x_grad->set_format(in_x_grad_format);
} // Compute()

@ -66,8 +66,6 @@ class SoftmaxMKLDNNHandler
auto diff_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
data_softmax_md, axis);
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
}

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save