|
|
|
@ -33,22 +33,7 @@ class ElementwiseMulOp : public ElementwiseOp {
|
|
|
|
|
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
using mkldnn::memory;
|
|
|
|
|
auto CanMKLDNNElementwiseMulBeUsed = [&]() {
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto y_dims = ctx.Input<Tensor>("Y")->dims();
|
|
|
|
|
int rankdiff = x_dims.size() - y_dims.size();
|
|
|
|
|
// TODO(jczaja): Remove this when oneDNN performance for scalar
|
|
|
|
|
// broadcasting
|
|
|
|
|
// is improved (Ernie large situation)
|
|
|
|
|
if (rankdiff != 0 && y_dims.size() == 1 && y_dims[0] == 1) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx) && CanMKLDNNElementwiseMulBeUsed()) {
|
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
|
|
|
|
|
framework::DataLayout::kMKLDNN,
|
|
|
|
|
framework::LibraryType::kMKLDNN);
|
|
|
|
|