|
|
|
@ -302,8 +302,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
|
|
|
|
|
int groups = ctx.Attr<int>("groups");
|
|
|
|
|
|
|
|
|
|
// TODO(pzelazko-intel) add support for group convolution and dilation
|
|
|
|
|
PADDLE_ENFORCE(groups == 1, "group convolution is not implemented yet");
|
|
|
|
|
// TODO: add support for dilation
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
|
|
|
|
|
"dilation in convolution is not implemented yet");
|
|
|
|
@ -314,6 +313,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
|
|
|
|
|
std::vector<int> weights_tz =
|
|
|
|
|
paddle::framework::vectorize2int(filter->dims());
|
|
|
|
|
int g = std::max(groups, 1);
|
|
|
|
|
if (g > 1) {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int h = weights_tz[2];
|
|
|
|
|
int w = weights_tz[3];
|
|
|
|
|
weights_tz.resize(5);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = h;
|
|
|
|
|
weights_tz[4] = w;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
|
|
|
|
|
|
|
|
|
// Get unique name for storing MKLDNN primitives
|
|
|
|
@ -327,7 +339,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto user_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
|
|
|
|
|
auto user_weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
|
|
|
|
|
{weights_tz}, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goihw);
|
|
|
|
|
|
|
|
|
|
/* create memory descriptor for convolution without specified format
|
|
|
|
|
* ('any') which lets a primitive (convolution in this case) choose
|
|
|
|
@ -340,7 +353,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
auto weights_md = platform::MKLDNNMemDesc(
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|
weights_tz, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
|
|
|
|
|
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
|
|
|
|
|
// Currently used whenever bias is != nullptr.
|
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|