|
|
|
@ -13,13 +13,14 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <mkldnn/include/mkldnn.hpp>
|
|
|
|
|
#include "paddle/fluid/operators/elementwise_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/elementwise_op_function.h"
|
|
|
|
|
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
|
|
|
|
|
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
|
|
|
|
|
#include "xbyak/xbyak.h"
|
|
|
|
|
#include "xbyak/xbyak_util.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_kernel.h"
|
|
|
|
|
#include "xbyak.h"
|
|
|
|
|
#include "xbyak_util.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -27,47 +28,6 @@ namespace operators {
|
|
|
|
|
using framework::DataLayout;
|
|
|
|
|
using mkldnn::memory;
|
|
|
|
|
|
|
|
|
|
// JIT-generated AVX-512 routine that multiplies h * w consecutive groups of
// 16 floats from X by one 16-float vector Y and stores the products to Z.
//
// The emitted code is meant to be called through a pointer of type
//   void (*)(const float* x, const float* y, float* z, int h, int w)
// with the arguments in the registers listed below (the System V x86-64
// integer argument order). h and w are 32-bit ints, so only the low halves of
// RCX / R8 are meaningful and all loop comparisons use the 32-bit views.
struct vector_mul : public Xbyak::CodeGenerator {
  vector_mul() {
    // RDI is ptr X
    // RSI is ptr Y
    // RDX is ptr Z
    // ECX is h
    // R8D is w

    // The loops below are do-while shaped; bail out before touching memory
    // when either extent is zero or negative. This happens before RBX is
    // pushed, so the exit label sits after the matching pop.
    test(ecx, ecx);
    jle("done", T_NEAR);
    test(r8d, r8d);
    jle("done", T_NEAR);

    push(rbx);  // RBX is callee-saved and is used as the inner counter.

    xor_(rax, rax);           // RAX: running byte offset into X and Z.
    xor_(r10d, r10d);         // R10D: outer (h) counter.
    vmovups(zmm3, ptr[rsi]);  // Y is loaded once and reused for every group.

    L("h_loop");
    xor_(ebx, ebx);  // EBX: inner (w) counter, reset for every outer step.
    L("w_loop");
    vmovups(zmm2, ptr[rdi + rax]);
    vmulps(zmm1, zmm2, zmm3);
    vmovups(ptr[rdx + rax], zmm1);
    add(rax, 64);  // Advance by one 512-bit vector (16 floats).
    inc(ebx);
    cmp(r8d, ebx);
    jnz("w_loop");
    inc(r10d);
    cmp(r10d, ecx);
    jnz("h_loop");

    pop(rbx);
    L("done");
    ret();
  }
};
|
|
|
|
|
|
|
|
|
|
// Scalar reference for the broadcast multiply: for each of the `w` groups of
// `simd_width` consecutive floats in `x`, multiplies element `lane` by
// `y[lane]` and writes the product to the same position in `z`.
//
// x, z       : arrays of at least w * simd_width floats.
// y          : array of at least simd_width floats, reused for every group.
// w          : number of groups; values <= 0 leave `z` untouched.
// simd_width : floats per group; defaults to 16, the previously hard-coded
//              value, so existing four-argument calls behave as before.
void check(const float* x, const float* y, float* z, int w,
           int simd_width = 16) {
  // Index arithmetic is done in long long so that w * simd_width cannot
  // overflow int for large inputs.
  for (long long group = 0; group < w; ++group) {
    const long long base = group * simd_width;
    for (int lane = 0; lane < simd_width; ++lane) {
      z[base + lane] = x[base + lane] * y[lane];
    }
  }
}
|
|
|
|
|
|
|
|
|
|
static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
|
|
|
|
|
std::transform(format.begin(), format.end(), format.begin(), ::tolower);
|
|
|
|
|
|
|
|
|
@ -163,12 +123,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
constexpr int simd_width = 16;
|
|
|
|
|
int C = c / simd_width;
|
|
|
|
|
|
|
|
|
|
vector_mul mul;
|
|
|
|
|
|
|
|
|
|
using mul_func_t =
|
|
|
|
|
void (*)(const float*, const float*, float*, int, int);
|
|
|
|
|
|
|
|
|
|
mul_func_t mul_func = (mul_func_t)mul.getCode();
|
|
|
|
|
const auto& multiply =
|
|
|
|
|
math::jitkernel::KernelPool::Instance()
|
|
|
|
|
.template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
|
|
|
|
|
|
|
|
|
|
#pragma omp parallel for collapse(2)
|
|
|
|
|
for (int ni = 0; ni < n; ni++) {
|
|
|
|
@ -180,7 +137,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto ptr_z =
|
|
|
|
|
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
|
|
|
|
|
|
|
|
|
|
mul_func(ptr_x, ptr_y, ptr_z, h, w);
|
|
|
|
|
multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|