@@ -19,36 +19,21 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
 #include "paddle/fluid/operators/math/jit_kernel.h"
-#include "xbyak.h"
-#include "xbyak_util.h"
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"
 
 namespace paddle {
 namespace operators {
 
 using framework::DataLayout;
 using mkldnn::memory;
-
-static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
-  std::transform(format.begin(), format.end(), format.begin(), ::tolower);
-
-  if (!format.compare("nchw")) {
-    return memory::format::nchw;
-  } else if (!format.compare("nchw16c")) {
-    return memory::format::nChw16c;
-  } else if (!format.compare("nchw8c")) {
-    return memory::format::nChw8c;
-  } else if (!format.compare("nhwc")) {
-    return memory::format::nhwc;
-  } else {
-    return memory::format::any;
-  }
-}
+using platform::StringToMKLDNNFormat;
 
 static void UpdateDataFormat(const framework::ExecutionContext& ctx,
                              framework::Tensor* tensor, const char* attribute) {
   if (ctx.op().HasAttr(attribute)) {
     auto format_as_string = ctx.Attr<std::string>(attribute);
-    auto format = StringToMKLDNNFormat(format_as_string);
+    auto format = StringToMKLDNNFormat(&format_as_string);
     if (format != memory::format::any) {
       tensor->set_format(format);
     }
@@ -93,8 +78,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
     auto y_dims_untrimmed = y->dims();
     auto x_int_dims = paddle::framework::vectorize2int(x_dims);
 
-    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
-    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
+    UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format");
+    UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format");
 
     Xbyak::util::Cpu cpu;
     const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
@@ -156,10 +141,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
       auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
       const auto& mkldnn_engine = dev_ctx.GetEngine();
 
       if (!(is_x_nchw || is_x_nc))
-        ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
+        ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
                         x->dims().size() == 4);
       if (!(is_y_nchw || is_y_nc))
-        ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
+        ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
                         y->dims().size() == 4);
     }
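
Note: the patch replaces the local static StringToMKLDNNFormat with a using declaration for platform::StringToMKLDNNFormat and changes the call site to pass &format_as_string, which suggests the helper now lives in paddle/fluid/platform/mkldnn_helper.h (already included above) and takes a std::string*. The sketch below shows what that relocated helper is presumed to look like, mirroring the mapping of the removed static function; the exact signature, header layout, and include path are assumptions, not part of this patch.

// Sketch only: assumed shape of platform::StringToMKLDNNFormat after the move.
// It lower-cases the attribute string in place and maps it to an MKL-DNN
// memory format, falling back to `any` for unrecognized values.
#include <algorithm>
#include <string>

#include "mkldnn.hpp"  // assumed include path for the MKL-DNN C++ API

namespace paddle {
namespace platform {

inline mkldnn::memory::format StringToMKLDNNFormat(std::string* format) {
  std::transform(format->begin(), format->end(), format->begin(), ::tolower);

  if (!format->compare("nchw")) {
    return mkldnn::memory::format::nchw;
  } else if (!format->compare("nchw16c")) {
    return mkldnn::memory::format::nChw16c;
  } else if (!format->compare("nchw8c")) {
    return mkldnn::memory::format::nChw8c;
  } else if (!format->compare("nhwc")) {
    return mkldnn::memory::format::nhwc;
  } else {
    return mkldnn::memory::format::any;
  }
}

}  // namespace platform
}  // namespace paddle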