|
|
|
|
@ -95,6 +95,26 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static void ReorderInput(framework::Tensor* tensor,
|
|
|
|
|
const platform::Place& place,
|
|
|
|
|
const mkldnn::engine& engine,
|
|
|
|
|
bool isFourDim) {
|
|
|
|
|
using platform::to_void_cast;
|
|
|
|
|
auto dims = paddle::framework::vectorize2int(tensor->dims());
|
|
|
|
|
framework::Tensor out_tensor;
|
|
|
|
|
out_tensor.Resize(tensor->dims());
|
|
|
|
|
out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
|
|
|
|
|
out_tensor.set_layout(tensor->layout());
|
|
|
|
|
mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
tensor->format()}, engine}, to_void_cast<T>(tensor->data<T>())};
|
|
|
|
|
mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
out_tensor.format()}, engine},
|
|
|
|
|
to_void_cast<T>(out_tensor.mutable_data<T>(place))};
|
|
|
|
|
platform::Reorder(input_memory, output_memory);
|
|
|
|
|
tensor->ShareDataWith(out_tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
@ -111,12 +131,15 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims_untrimmed = y->dims();
|
|
|
|
|
auto x_int_dims = paddle::framework::vectorize2int(x_dims);
|
|
|
|
|
|
|
|
|
|
UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
|
|
|
|
|
UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
|
|
|
|
|
|
|
|
|
|
if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) {
|
|
|
|
|
if (x_dims != y_dims_untrimmed) {
|
|
|
|
|
const bool are_dims_divisable = !(x_int_dims[1] % 16);
|
|
|
|
|
const bool is_x_format_correct = x->format() == memory::format::nChw16c;
|
|
|
|
|
const bool is_y_format_correct = y->format() == memory::format::nc;
|
|
|
|
|
if (is_x_format_correct && is_y_format_correct && are_dims_divisable) {
|
|
|
|
|
int pre, n, post;
|
|
|
|
|
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
|
|
|
|
|
|
|
|
|
|
@ -163,11 +186,23 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
z->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
z->set_format(x->format());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Not implemented when dims are equal");
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// Fallback to naive version:
|
|
|
|
|
const bool are_inputs_in_same_format = x->format() == y->format();
|
|
|
|
|
const bool is_x_nchw= x->format() == memory::format::nchw;
|
|
|
|
|
const bool is_x_nc = x->format() == memory::format::nc;
|
|
|
|
|
const bool is_y_nchw= y->format() == memory::format::nchw;
|
|
|
|
|
const bool is_y_nc = y->format() == memory::format::nc;
|
|
|
|
|
if(!are_inputs_in_same_format) {
|
|
|
|
|
using platform::MKLDNNDeviceContext;
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
|
|
|
|
|
const auto& mkldnn_engine = dev_ctx.GetEngine();
|
|
|
|
|
if(!(is_x_nchw || is_x_nc))
|
|
|
|
|
ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4);
|
|
|
|
|
if(!(is_y_nchw || is_y_nc))
|
|
|
|
|
ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto mul_func = [](T a, T b) -> T { return a * b; };
|
|
|
|
|
|
|
|
|
|
TransformFunctor<decltype(mul_func), T,
|
|
|
|
|
|