convert output to nchw format to align with native version in avx512 mode

test = develop
resolve #16764
shanyi15-patch-1
Leo Zhao 6 years ago
parent 85363848a1
commit a9694bd3d6

@ -130,6 +130,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format());
// convert to nchw format to align with native version
using platform::MKLDNNDeviceContext;
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
ReorderInput<T>(const_cast<Tensor*>(z), ctx.GetPlace(), mkldnn_engine,
z->dims().size() == 4);
} else {
// Fallback to naive version:
const bool are_inputs_in_same_format = x->format() == y->format();

@ -28,7 +28,8 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
self.y = np.random.rand(1, 16).astype(self.dtype)
self.out = x * self.y.reshape(1, 16, 1, 1)
self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
# self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
def setUp(self):
super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp()

Loading…
Cancel
Save