|
|
|
@ -16,6 +16,8 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/jit_gen.h"
|
|
|
|
|
#include "paddle/fluid/platform/cpu_info.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace math {
|
|
|
|
@ -40,6 +42,51 @@ typedef enum {
|
|
|
|
|
identity
|
|
|
|
|
} operand_type;
|
|
|
|
|
|
|
|
|
|
// Tables and scratch memory used by the generated code. They are only
// declared here (extern).
//
// exp_float_consts: float table that exp_jmm() indexes with the OFFSET_*
// byte offsets defined below, using aligned vector loads (vmovaps).
extern const float exp_float_consts[];
// exp_int_0x7f: integer vector that exp_jmm() adds to the truncated exponent
// before shifting it left by 23 bits to build 2^n.
extern const int exp_int_0x7f[];
// g_tmp_mem: scratch buffer used by the AVX (non-AVX2) path of exp_jmm(); the
// generated code stores two YMM-sized vectors into it with aligned stores.
// NOTE(review): this is a single global written by the generated code, so
// kernels running concurrently would presumably race on it -- confirm.
extern int g_tmp_mem[];

// TODO(TJ): move these to some proper place
// NOTE(review): the three macros below are intentionally left token-identical.
// The TODO suggests they may also be defined in another header; an identical
// redefinition is legal, a differing one is not -- confirm before changing.
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0

// Number of float lanes in an XMM / YMM / ZMM register.
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16

// Attribute requesting 32-byte alignment for a declaration.
#define ALIGN32 __attribute__((aligned(32)))

// Constants of the exp approximation: upper/lower input bounds, log2(e), the
// two parts of the range-reduction constant (C1, C2) and the polynomial
// coefficients P0..P5.
// NOTE(review): presumably these are the values stored in exp_float_consts in
// the slot order given by the OFFSET_* macros -- verify against the table
// definition.
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1

// Expands to `val` repeated eight times, separated by commas (one value per
// YMM float lane); intended for brace initializers.
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val

// Byte offsets of the slots inside exp_float_consts. Every slot is one YMM
// block (YMM_FLOAT_BLOCK floats) wide. The replacement lists are parenthesized
// so that the macros keep their value inside larger expressions such as
// `x / OFFSET_EXP_TWO` or `x % OFFSET_EXP_TWO`.
#define OFFSET_EXP_ONE (0 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_TWO (1 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_0P5 (2 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_HIG (3 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_LOW (4 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_LOG2EF (5 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_C1 (6 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_C2 (7 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P0 (8 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P1 (9 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P2 (10 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P3 (11 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P4 (12 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_P5 (13 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_EXP_MAX_INPUT (14 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_SIGMOID_MAX (15 * YMM_FLOAT_BLOCK * sizeof(float))
#define OFFSET_SIGMOID_MIN (16 * YMM_FLOAT_BLOCK * sizeof(float))
|
|
|
|
|
|
|
|
|
|
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
|
|
|
|
|
class VXXJitCode : public JitCode {
|
|
|
|
|
public:
|
|
|
|
@ -134,10 +181,87 @@ class VActJitCode : public JitCode {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Compute exp with ymm, xmm registers.
// dst / src: destination and source vector registers.
// fx_idx, fy_idx, mask_idx, tmp_idx: indices of the scratch registers to use;
// the defaults (2, 3, 4, 5) are the same as those of exp_jmm.
// Only declared here.
// NOTE(review): presumably thin wrappers around exp_jmm -- confirm in the
// implementation file.
|
|
|
|
|
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
|
|
|
|
|
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
|
|
|
|
|
// Same as exp_ymm, for 128-bit xmm registers.
void exp_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, int fx_idx = 2,
|
|
|
|
|
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT
|
|
|
|
|
int mask_idx = 4, int tmp_idx = 5) {
|
|
|
|
|
using namespace platform::jit; // NOLINT
|
|
|
|
|
assert(src.getIdx() != dst.getIdx()); // TODO(TJ): use enfore
|
|
|
|
|
// check all idx can not equal
|
|
|
|
|
JMM jmm_fx = JMM(fx_idx);
|
|
|
|
|
JMM jmm_fy = JMM(fy_idx);
|
|
|
|
|
JMM jmm_mask = JMM(mask_idx);
|
|
|
|
|
JMM jmm_tmp = JMM(tmp_idx);
|
|
|
|
|
reg64_t reg_ptr_global = rax;
|
|
|
|
|
push(reg_ptr_global);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
|
|
|
|
|
vminps(src, src, jmm_tmp);
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
|
|
|
|
|
vmaxps(src, src, jmm_tmp);
|
|
|
|
|
// express exp(x) as exp(g + n*log(2))
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
|
|
|
|
|
vmulps(jmm_fx, src, jmm_tmp);
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
|
|
|
|
|
vaddps(jmm_fx, jmm_fx, jmm_tmp);
|
|
|
|
|
vroundps(jmm_fy, jmm_fx, 0x01);
|
|
|
|
|
// if greater, substract 1
|
|
|
|
|
vcmpgtps(jmm_mask, jmm_fy, jmm_fx);
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
vandps(jmm_mask, jmm_mask, jmm_tmp);
|
|
|
|
|
vsubps(jmm_fx, jmm_fy, jmm_mask);
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
|
|
|
|
|
vmulps(jmm_fy, jmm_fx, jmm_tmp);
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
|
|
|
|
|
JMM ymm_z = JMM(jmm_mask.getIdx());
|
|
|
|
|
vmulps(ymm_z, jmm_fx, jmm_tmp);
|
|
|
|
|
vsubps(src, src, jmm_fy);
|
|
|
|
|
vsubps(src, src, ymm_z);
|
|
|
|
|
vmulps(ymm_z, src, src);
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
|
|
|
|
|
vmulps(dst, src, jmm_tmp);
|
|
|
|
|
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
|
|
|
|
|
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4
|
|
|
|
|
vaddps(dst, dst, jmm_tmp);
|
|
|
|
|
vmulps(dst, dst, src);
|
|
|
|
|
}
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
|
|
|
|
|
vaddps(dst, dst, jmm_tmp);
|
|
|
|
|
vmulps(dst, dst, ymm_z);
|
|
|
|
|
vaddps(dst, dst, src);
|
|
|
|
|
vmovaps(jmm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
vaddps(dst, dst, jmm_tmp);
|
|
|
|
|
// build 2^n
|
|
|
|
|
JMM ymm_int = jmm_fx;
|
|
|
|
|
vcvttps2dq(ymm_int, jmm_fx);
|
|
|
|
|
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
|
|
|
|
|
vmovdqa(jmm_tmp, ptr[reg_ptr_global]);
|
|
|
|
|
if (MayIUse(avx2) || std::is_same<JMM, xmm_t>::value) {
|
|
|
|
|
vpaddd(ymm_int, ymm_int, jmm_tmp);
|
|
|
|
|
vpslld(ymm_int, ymm_int, 23);
|
|
|
|
|
} else if (MayIUse(avx)) {
|
|
|
|
|
xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
|
|
|
|
|
xmm_t xtmp2 = xmm_t(jmm_tmp.getIdx());
|
|
|
|
|
reg64_t reg_ptr_tmp = reg_ptr_global;
|
|
|
|
|
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
|
|
|
|
|
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
|
|
|
|
|
vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], jmm_tmp);
|
|
|
|
|
vpaddd(xtmp1, xtmp1, xtmp2);
|
|
|
|
|
vpslld(xtmp1, xtmp1, 23);
|
|
|
|
|
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
|
|
|
|
|
// next 128bits
|
|
|
|
|
vmovdqa(xtmp1, ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)]);
|
|
|
|
|
vmovdqa(xtmp2, ptr[reg_ptr_tmp +
|
|
|
|
|
(YMM_FLOAT_BLOCK + XMM_FLOAT_BLOCK) * sizeof(float)]);
|
|
|
|
|
vpaddd(xtmp1, xtmp1, xtmp2);
|
|
|
|
|
vpslld(xtmp1, xtmp1, 23);
|
|
|
|
|
vmovdqa(ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)], xtmp1);
|
|
|
|
|
// load out
|
|
|
|
|
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
|
|
|
|
|
}
|
|
|
|
|
vmulps(dst, dst, ymm_int);
|
|
|
|
|
pop(reg_ptr_global);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute sigmoid with ymm
|
|
|
|
|
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
|
|
|
|
|