|
|
|
@ -52,10 +52,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
|
|
|
|
|
filter->format() != memory::format::format_undef,
|
|
|
|
|
"Wrong layout/format set for Filter tensor");
|
|
|
|
|
PADDLE_ENFORCE(input->dims().size() == 4,
|
|
|
|
|
"Input must be with 4 dimensions, i.e. NCHW");
|
|
|
|
|
PADDLE_ENFORCE(filter->dims().size() == 4,
|
|
|
|
|
"Filter must be with 4 dimensions, i.e. OIHW");
|
|
|
|
|
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
|
|
|
|
|
"Input must be with 4 or 5dimensions, i.e. NCHW or NCDHW");
|
|
|
|
|
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
|
|
|
|
|
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
|
|
|
|
|
if (bias) {
|
|
|
|
|
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
|
|
|
|
|
bias->format() != memory::format::format_undef,
|
|
|
|
@ -71,9 +71,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
|
|
|
|
|
int groups = ctx.Attr<int>("groups");
|
|
|
|
|
|
|
|
|
|
bool is_conv3d = strides.size() == 3U;
|
|
|
|
|
// TODO(tpatejko): add support for dilation
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
|
|
|
|
|
is_conv3d
|
|
|
|
|
? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
|
|
|
|
|
dilations[2] == 1
|
|
|
|
|
: dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
|
|
|
|
|
"dilation in convolution is not implemented yet");
|
|
|
|
|
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
@ -84,16 +88,31 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
paddle::framework::vectorize2int(filter->dims());
|
|
|
|
|
int g = std::max(groups, 1);
|
|
|
|
|
if (g > 1) {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int h = weights_tz[2];
|
|
|
|
|
int w = weights_tz[3];
|
|
|
|
|
weights_tz.resize(5);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = h;
|
|
|
|
|
weights_tz[4] = w;
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int d = weights_tz[2];
|
|
|
|
|
int h = weights_tz[3];
|
|
|
|
|
int w = weights_tz[4];
|
|
|
|
|
weights_tz.resize(6);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = d;
|
|
|
|
|
weights_tz[4] = h;
|
|
|
|
|
weights_tz[5] = w;
|
|
|
|
|
} else {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int h = weights_tz[2];
|
|
|
|
|
int w = weights_tz[3];
|
|
|
|
|
weights_tz.resize(5);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = h;
|
|
|
|
|
weights_tz[4] = w;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
|
|
|
|
|
|
|
|
@ -105,11 +124,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
std::vector<primitive> pipeline;
|
|
|
|
|
|
|
|
|
|
auto src_format = input->format();
|
|
|
|
|
mkldnn::memory::format weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto user_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
|
|
|
|
|
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
|
|
|
|
|
auto user_weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{weights_tz}, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goihw);
|
|
|
|
|
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
|
|
|
|
|
|
|
|
|
|
/* create memory descriptor for convolution without specified format
|
|
|
|
|
* ('any') which lets a primitive (convolution in this case) choose
|
|
|
|
@ -119,10 +146,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
|
|
|
|
|
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
|
|
|
|
|
// Currently used whenever bias is != nullptr.
|
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc(
|
|
|
|
@ -263,8 +300,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const mkldnn::engine& engine, const bool fuse_relu,
|
|
|
|
|
const bool fuse_residual_conn,
|
|
|
|
|
mkldnn::prop_kind fwd_prop_kind) const {
|
|
|
|
|
memory::dims stride_dims = {strides[0], strides[1]};
|
|
|
|
|
memory::dims padding_dims = {paddings[0], paddings[1]};
|
|
|
|
|
memory::dims stride_dims = strides;
|
|
|
|
|
memory::dims padding_dims = paddings;
|
|
|
|
|
|
|
|
|
|
auto conv_desc = mkldnn::convolution_forward::desc(
|
|
|
|
|
fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
|
|
|
|
@ -288,8 +325,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
const mkldnn::engine& engine, const bool fuse_relu,
|
|
|
|
|
const bool fuse_residual_conn,
|
|
|
|
|
mkldnn::prop_kind fwd_prop_kind) const {
|
|
|
|
|
memory::dims stride_dims = {strides[0], strides[1]};
|
|
|
|
|
memory::dims padding_dims = {paddings[0], paddings[1]};
|
|
|
|
|
memory::dims stride_dims = strides;
|
|
|
|
|
memory::dims padding_dims = paddings;
|
|
|
|
|
|
|
|
|
|
auto conv_desc = mkldnn::convolution_forward::desc(
|
|
|
|
|
fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
|
|
|
|
@ -349,6 +386,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
|
|
|
|
|
int groups = ctx.Attr<int>("groups");
|
|
|
|
|
|
|
|
|
|
bool is_conv3d = strides.size() == 3U;
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
const T* filter_data = filter->data<T>();
|
|
|
|
|
const T* output_grad_data = output_grad->data<T>();
|
|
|
|
@ -358,8 +396,45 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
|
|
|
|
|
std::vector<int> weights_tz =
|
|
|
|
|
paddle::framework::vectorize2int(filter->dims());
|
|
|
|
|
int g = std::max(groups, 1);
|
|
|
|
|
if (g > 1) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int d = weights_tz[2];
|
|
|
|
|
int h = weights_tz[3];
|
|
|
|
|
int w = weights_tz[4];
|
|
|
|
|
weights_tz.resize(6);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = d;
|
|
|
|
|
weights_tz[4] = h;
|
|
|
|
|
weights_tz[5] = w;
|
|
|
|
|
} else {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int h = weights_tz[2];
|
|
|
|
|
int w = weights_tz[3];
|
|
|
|
|
weights_tz.resize(5);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = h;
|
|
|
|
|
weights_tz[4] = w;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
|
|
|
|
|
|
|
|
|
auto src_format = input->format();
|
|
|
|
|
mkldnn::memory::format weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get a unique name from "argument" name of "Output" variable
|
|
|
|
|
// as well as attributes of primitive to be created
|
|
|
|
|
// This name will be used as key when saving info into device context
|
|
|
|
@ -372,9 +447,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// Create user memory descriptors
|
|
|
|
|
auto user_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
|
|
|
|
|
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
|
|
|
|
|
auto user_weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
|
|
|
|
|
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
|
|
|
|
|
auto user_diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
|
|
|
|
|
|
|
|
|
@ -386,14 +461,24 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto diff_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
|
|
|
|
|
auto diff_weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
|
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
|
|
|
|
@ -496,3 +581,9 @@ REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
|
|
|
|
|
|
|
|
|
|
// Register the float instantiation of ConvMKLDNNGradOpKernel as the MKLDNN
// kernel for the conv2d_grad operator on CPUPlace.
REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
|
|
|
|
|
ops::ConvMKLDNNGradOpKernel<float>);
|
|
|
|
|
|
|
|
|
|
// Register the float instantiation of ConvMKLDNNOpKernel as the MKLDNN kernel
// for the conv3d operator on CPUPlace. The same kernel class is used for
// conv3d_grad's forward counterpart; it distinguishes the 3-D case at runtime
// via `strides.size() == 3U`.
REGISTER_OP_KERNEL(conv3d, MKLDNN, ::paddle::platform::CPUPlace,
|
|
|
|
|
ops::ConvMKLDNNOpKernel<float>);
|
|
|
|
|
|
|
|
|
|
// Register the float instantiation of ConvMKLDNNGradOpKernel as the MKLDNN
// kernel for the conv3d_grad operator on CPUPlace.
REGISTER_OP_KERNEL(conv3d_grad, MKLDNN, ::paddle::platform::CPUPlace,
|
|
|
|
|
ops::ConvMKLDNNGradOpKernel<float>);
|
|
|
|
|