|
|
|
@ -14,10 +14,13 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_kernel.h"
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_code.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_XBYAK
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_code.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
#include "paddle/fluid/platform/dynload/mklml.h"
|
|
|
|
|
#endif
|
|
|
|
@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel<T> {
|
|
|
|
|
// Generic fallback policy: the MKL path is never selected, whatever the
// length d is. Type-specific specializations may override this answer.
static inline bool useMKL(int d) {
  return false;
}
|
|
|
|
|
|
|
|
|
|
explicit VMulKernelImpl(int d) : VMulKernel<T>() {
|
|
|
|
|
#ifdef PADDLE_WITH_XBYAK
|
|
|
|
|
if (useJIT(d)) {
|
|
|
|
|
// roughly estimate the size of code
|
|
|
|
|
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
|
|
|
|
@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel<T> {
|
|
|
|
|
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
if (useMKL(d)) {
|
|
|
|
|
this->Compute = VMulMKL<T>;
|
|
|
|
@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel<T> {
|
|
|
|
|
this->Compute = VMulRefer<T>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_XBYAK
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
|
|
|
|
|
#endif
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_XBYAK
|
|
|
|
|
template <>
|
|
|
|
|
bool VMulKernelImpl<float>::useJIT(int d) {
|
|
|
|
|
return gen::VMulJitCode::init(d);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
template <>
|
|
|
|
|
bool VMulKernelImpl<float>::useMKL(int d) {
|
|
|
|
|
return jit::MayIUse(jit::avx512f) && d > 512;
|
|
|
|
@ -99,6 +110,7 @@ template <>
|
|
|
|
|
bool VMulKernelImpl<double>::useMKL(int d) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// NOTE(review): presumably registers the VMulKernel family under the name
// `vmul` using the macro from jit_kernel_macro.h — confirm against the macro
// definition, which is not visible here.
REGISTER_JITKERNEL(vmul, VMulKernel);
|
|
|
|
|
|
|
|
|
|