|
|
|
@ -28,6 +28,46 @@ using mkldnn::stream;
|
|
|
|
|
using platform::to_void_cast;
|
|
|
|
|
using platform::GetMKLDNNFormat;
|
|
|
|
|
|
|
|
|
|
inline void GetWeightsTz(std::vector<int>& weights_tz, int groups, // NOLINT
|
|
|
|
|
bool is_conv3d) {
|
|
|
|
|
if (groups > 1) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
int output = weights_tz[0];
|
|
|
|
|
int input = weights_tz[1];
|
|
|
|
|
int dimension = weights_tz[2];
|
|
|
|
|
int height = weights_tz[3];
|
|
|
|
|
int width = weights_tz[4];
|
|
|
|
|
weights_tz.resize(6);
|
|
|
|
|
weights_tz[0] = groups;
|
|
|
|
|
weights_tz[1] = output / groups;
|
|
|
|
|
weights_tz[2] = input;
|
|
|
|
|
weights_tz[3] = dimension;
|
|
|
|
|
weights_tz[4] = height;
|
|
|
|
|
weights_tz[5] = width;
|
|
|
|
|
} else {
|
|
|
|
|
int output = weights_tz[0];
|
|
|
|
|
int input = weights_tz[1];
|
|
|
|
|
int height = weights_tz[2];
|
|
|
|
|
int width = weights_tz[3];
|
|
|
|
|
weights_tz.resize(5);
|
|
|
|
|
weights_tz[0] = groups;
|
|
|
|
|
weights_tz[1] = output / groups;
|
|
|
|
|
weights_tz[2] = input;
|
|
|
|
|
weights_tz[3] = height;
|
|
|
|
|
weights_tz[4] = width;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
|
|
|
|
|
int groups, bool is_conv3d) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
return (groups == 1) ? format : mkldnn::memory::format::goidhw;
|
|
|
|
|
} else {
|
|
|
|
|
return (groups == 1) ? format : mkldnn::memory::format::goihw;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -53,7 +93,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
filter->format() != memory::format::format_undef,
|
|
|
|
|
"Wrong layout/format set for Filter tensor");
|
|
|
|
|
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
|
|
|
|
|
"Input must be with 4 or 5dimensions, i.e. NCHW or NCDHW");
|
|
|
|
|
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
|
|
|
|
|
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
|
|
|
|
|
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
|
|
|
|
|
if (bias) {
|
|
|
|
@ -87,33 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> weights_tz =
|
|
|
|
|
paddle::framework::vectorize2int(filter->dims());
|
|
|
|
|
int g = std::max(groups, 1);
|
|
|
|
|
if (g > 1) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int d = weights_tz[2];
|
|
|
|
|
int h = weights_tz[3];
|
|
|
|
|
int w = weights_tz[4];
|
|
|
|
|
weights_tz.resize(6);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = d;
|
|
|
|
|
weights_tz[4] = h;
|
|
|
|
|
weights_tz[5] = w;
|
|
|
|
|
} else {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int h = weights_tz[2];
|
|
|
|
|
int w = weights_tz[3];
|
|
|
|
|
weights_tz.resize(5);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = h;
|
|
|
|
|
weights_tz[4] = w;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
GetWeightsTz(weights_tz, g, is_conv3d);
|
|
|
|
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
|
|
|
|
|
|
|
|
|
// Get unique name for storing MKLDNN primitives
|
|
|
|
@ -126,12 +140,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto src_format = input->format();
|
|
|
|
|
mkldnn::memory::format weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
GetWeightsFormat(filter->format(), g, is_conv3d);
|
|
|
|
|
|
|
|
|
|
auto user_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
|
|
|
|
@ -146,15 +155,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
@ -397,43 +402,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
std::vector<int> weights_tz =
|
|
|
|
|
paddle::framework::vectorize2int(filter->dims());
|
|
|
|
|
int g = std::max(groups, 1);
|
|
|
|
|
if (g > 1) {
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int d = weights_tz[2];
|
|
|
|
|
int h = weights_tz[3];
|
|
|
|
|
int w = weights_tz[4];
|
|
|
|
|
weights_tz.resize(6);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = d;
|
|
|
|
|
weights_tz[4] = h;
|
|
|
|
|
weights_tz[5] = w;
|
|
|
|
|
} else {
|
|
|
|
|
int o = weights_tz[0];
|
|
|
|
|
int i = weights_tz[1];
|
|
|
|
|
int h = weights_tz[2];
|
|
|
|
|
int w = weights_tz[3];
|
|
|
|
|
weights_tz.resize(5);
|
|
|
|
|
weights_tz[0] = g;
|
|
|
|
|
weights_tz[1] = o / g;
|
|
|
|
|
weights_tz[2] = i;
|
|
|
|
|
weights_tz[3] = h;
|
|
|
|
|
weights_tz[4] = w;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
GetWeightsTz(weights_tz, g, is_conv3d);
|
|
|
|
|
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
|
|
|
|
|
|
|
|
|
|
auto src_format = input->format();
|
|
|
|
|
mkldnn::memory::format weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
GetWeightsFormat(filter->format(), g, is_conv3d);
|
|
|
|
|
|
|
|
|
|
// Get an unique name from "argument" name of "Output" variable
|
|
|
|
|
// as well as attributes of primitive to be created
|
|
|
|
@ -461,15 +435,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto chosen_memory_format =
|
|
|
|
|
platform::data_format_to_memory_format(data_format);
|
|
|
|
|
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
|
|
|
|
|
|
|
|
|
|
if (is_conv3d) {
|
|
|
|
|
chosen_memory_format =
|
|
|
|
|
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
|
|
|
|
|
weights_format =
|
|
|
|
|
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
|
|
|
|
|
}
|
|
|
|
|
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
|
|
|
|
|
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
|
|
|
|
|