|
|
|
@ -300,7 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
|
|
|
|
|
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
|
|
|
|
|
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
|
|
|
|
|
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
|
|
|
|
|
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
|
|
|
|
|
int groups = ctx.Attr<int>("groups");
|
|
|
|
|
|
|
|
|
|
// TODO(tpatejko): add support for dilation
|
|
|
|
@ -369,11 +369,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
|
|
|
|
|
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
|
|
|
|
|
strides, paddings, mkldnn_engine,
|
|
|
|
|
fuse_relu, fuse_eltwise);
|
|
|
|
|
fuse_relu, fuse_residual_conn);
|
|
|
|
|
} else {
|
|
|
|
|
conv_pd =
|
|
|
|
|
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
|
|
|
|
|
mkldnn_engine, fuse_relu, fuse_eltwise);
|
|
|
|
|
mkldnn_engine, fuse_relu, fuse_residual_conn);
|
|
|
|
|
}
|
|
|
|
|
// Save conv_pd/src_memory/weights_memory for backward pass
|
|
|
|
|
dev_ctx.SetBlob(key_conv_pd, conv_pd);
|
|
|
|
@ -388,7 +388,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
T* output_data = nullptr;
|
|
|
|
|
|
|
|
|
|
if (fuse_eltwise) {
|
|
|
|
|
if (fuse_residual_conn) {
|
|
|
|
|
auto residual_param = ctx.Input<Tensor>("ResidualData");
|
|
|
|
|
auto residual_param_data = residual_param->data<T>();
|
|
|
|
|
|
|
|
|
@ -442,14 +442,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
|
|
|
|
|
bool fuse_eltwise) const {
|
|
|
|
|
bool fuse_residual_conn) const {
|
|
|
|
|
mkldnn::primitive_attr conv_attr;
|
|
|
|
|
mkldnn::post_ops post_operations;
|
|
|
|
|
// Fusion with Elementwise layer relies on adding a sum post-operation with
|
|
|
|
|
// the scale parameter. It is assumed that when fuse_eltwise is true, the
|
|
|
|
|
// Output tensor contains the data coming from residual connection. The
|
|
|
|
|
// result of this post_op is: Output = scale * Output + Conv_Out.
|
|
|
|
|
if (fuse_eltwise) {
|
|
|
|
|
// the scale parameter. It is assumed that when fuse_residual_connection is
|
|
|
|
|
// true, the output tensor contains the data coming from residual
|
|
|
|
|
// connection. The result of this post_op is:
|
|
|
|
|
// Output = scale * Output + Conv_Out.
|
|
|
|
|
if (fuse_residual_conn) {
|
|
|
|
|
post_operations.append_sum(1.0f);
|
|
|
|
|
}
|
|
|
|
|
// Fusion with ReLU layer is executed through the PostOps feature. Create a
|
|
|
|
@ -470,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const memory::desc& dst, const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings,
|
|
|
|
|
const mkldnn::engine& engine, const bool fuse_relu,
|
|
|
|
|
const bool fuse_eltwise) const {
|
|
|
|
|
const bool fuse_residual_conn) const {
|
|
|
|
|
memory::dims stride_dims = {strides[0], strides[1]};
|
|
|
|
|
memory::dims padding_dims = {paddings[0], paddings[1]};
|
|
|
|
|
|
|
|
|
@ -479,7 +480,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
dst, stride_dims, padding_dims, padding_dims,
|
|
|
|
|
mkldnn::padding_kind::zero);
|
|
|
|
|
|
|
|
|
|
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
|
|
|
|
|
mkldnn::primitive_attr conv_attr =
|
|
|
|
|
CreatePostOps(fuse_relu, fuse_residual_conn);
|
|
|
|
|
|
|
|
|
|
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
|
|
|
|
|
conv_desc, conv_attr, engine);
|
|
|
|
@ -494,7 +496,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings,
|
|
|
|
|
const mkldnn::engine& engine, const bool fuse_relu,
|
|
|
|
|
const bool fuse_eltwise) const {
|
|
|
|
|
const bool fuse_residual_conn) const {
|
|
|
|
|
memory::dims stride_dims = {strides[0], strides[1]};
|
|
|
|
|
memory::dims padding_dims = {paddings[0], paddings[1]};
|
|
|
|
|
|
|
|
|
@ -503,7 +505,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
bias, dst, stride_dims, padding_dims, padding_dims,
|
|
|
|
|
mkldnn::padding_kind::zero);
|
|
|
|
|
|
|
|
|
|
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
|
|
|
|
|
mkldnn::primitive_attr conv_attr =
|
|
|
|
|
CreatePostOps(fuse_relu, fuse_residual_conn);
|
|
|
|
|
|
|
|
|
|
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
|
|
|
|
|
conv_desc, conv_attr, engine);
|
|
|
|
|