|
|
|
@ -136,10 +136,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
|
|
|
|
|
UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
|
|
|
|
|
|
|
|
|
|
Xbyak::util::Cpu cpu;
|
|
|
|
|
const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
|
|
|
|
|
const bool are_dims_divisable = !(x_int_dims[1] % 16);
|
|
|
|
|
const bool is_x_format_correct = x->format() == memory::format::nChw16c;
|
|
|
|
|
const bool is_y_format_correct = y->format() == memory::format::nc;
|
|
|
|
|
if (is_x_format_correct && is_y_format_correct && are_dims_divisable) {
|
|
|
|
|
if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
|
|
|
|
|
is_avx512_enabled) {
|
|
|
|
|
int pre, n, post;
|
|
|
|
|
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
|
|
|
|
|
|
|
|
|
|