|
|
@ -17,11 +17,29 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "xbyak/xbyak.h"
|
|
|
|
|
|
|
|
#include "xbyak/xbyak_util.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
using framework::DataLayout;
|
|
|
|
using framework::DataLayout;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct vector_mul : public Xbyak::CodeGenerator {
|
|
|
|
|
|
|
|
vector_mul() {
|
|
|
|
|
|
|
|
// RDI is ptr X
|
|
|
|
|
|
|
|
// RSI is ptr Y
|
|
|
|
|
|
|
|
// RDX is ptr Z
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vmovups(zmm2, ptr[rdi]);
|
|
|
|
|
|
|
|
vmovups(zmm3, ptr[rsi]);
|
|
|
|
|
|
|
|
vmulps(zmm1, zmm2, zmm3);
|
|
|
|
|
|
|
|
vmovups(ptr[rdx], zmm1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ret();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
@ -61,6 +79,14 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
constexpr int simd_width = 16;
|
|
|
|
constexpr int simd_width = 16;
|
|
|
|
int C = c / simd_width;
|
|
|
|
int C = c / simd_width;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vector_mul mul;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
using mul_func_t = void (*)(const float*, const float*, float*);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mul_func_t mul_func = (mul_func_t)mul.getCode();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto ptr_x = x_data;
|
|
|
|
|
|
|
|
|
|
|
|
for (int ni = 0; ni < n; ni++) {
|
|
|
|
for (int ni = 0; ni < n; ni++) {
|
|
|
|
for (int ci = 0; ci < C; ci++) {
|
|
|
|
for (int ci = 0; ci < C; ci++) {
|
|
|
|
for (int hi = 0; hi < h; hi++) {
|
|
|
|
for (int hi = 0; hi < h; hi++) {
|
|
|
@ -74,9 +100,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
ci * h * w * simd_width + hi * w * simd_width +
|
|
|
|
ci * h * w * simd_width + hi * w * simd_width +
|
|
|
|
wi * simd_width;
|
|
|
|
wi * simd_width;
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < simd_width; i++) {
|
|
|
|
mul_func(ptr_x, ptr_y, ptr_z);
|
|
|
|
ptr_z[i] = ptr_x[i] * ptr_y[i];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|