refine exp jitcode with all size

test=develop
panyx0718-patch-1
tensor-tang 6 years ago
parent d3eae8f61b
commit ccb8963705

File diff suppressed because it is too large Load Diff

@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace math {
@ -40,6 +42,51 @@ typedef enum {
identity
} operand_type;
extern const float exp_float_consts[];
extern const int exp_int_0x7f[];
extern int g_tmp_mem[];
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
@ -134,10 +181,87 @@ class VActJitCode : public JitCode {
}
// compute exp with ymm, xmm
void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
void exp_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, int fx_idx = 2,
int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5);
template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT
int mask_idx = 4, int tmp_idx = 5) {
using namespace platform::jit; // NOLINT
assert(src.getIdx() != dst.getIdx()); // TODO(TJ): use enfore
// check all idx can not equal
JMM jmm_fx = JMM(fx_idx);
JMM jmm_fy = JMM(fy_idx);
JMM jmm_mask = JMM(mask_idx);
JMM jmm_tmp = JMM(tmp_idx);
reg64_t reg_ptr_global = rax;
push(reg_ptr_global);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]);
vminps(src, src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]);
vmaxps(src, src, jmm_tmp);
// express exp(x) as exp(g + n*log(2))
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]);
vmulps(jmm_fx, src, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]);
vaddps(jmm_fx, jmm_fx, jmm_tmp);
vroundps(jmm_fy, jmm_fx, 0x01);
// if greater, substract 1
vcmpgtps(jmm_mask, jmm_fy, jmm_fx);
vmovaps(jmm_tmp, ptr[reg_ptr_global]);
vandps(jmm_mask, jmm_mask, jmm_tmp);
vsubps(jmm_fx, jmm_fy, jmm_mask);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]);
vmulps(jmm_fy, jmm_fx, jmm_tmp);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]);
JMM ymm_z = JMM(jmm_mask.getIdx());
vmulps(ymm_z, jmm_fx, jmm_tmp);
vsubps(src, src, jmm_fy);
vsubps(src, src, ymm_z);
vmulps(ymm_z, src, src);
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
vmulps(dst, src, jmm_tmp);
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4
vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, src);
}
vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]);
vaddps(dst, dst, jmm_tmp);
vmulps(dst, dst, ymm_z);
vaddps(dst, dst, src);
vmovaps(jmm_tmp, ptr[reg_ptr_global]);
vaddps(dst, dst, jmm_tmp);
// build 2^n
JMM ymm_int = jmm_fx;
vcvttps2dq(ymm_int, jmm_fx);
mov(reg_ptr_global, reinterpret_cast<size_t>(exp_int_0x7f));
vmovdqa(jmm_tmp, ptr[reg_ptr_global]);
if (MayIUse(avx2) || std::is_same<JMM, xmm_t>::value) {
vpaddd(ymm_int, ymm_int, jmm_tmp);
vpslld(ymm_int, ymm_int, 23);
} else if (MayIUse(avx)) {
xmm_t xtmp1 = xmm_t(ymm_int.getIdx());
xmm_t xtmp2 = xmm_t(jmm_tmp.getIdx());
reg64_t reg_ptr_tmp = reg_ptr_global;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], jmm_tmp);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
// next 128bits
vmovdqa(xtmp1, ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)]);
vmovdqa(xtmp2, ptr[reg_ptr_tmp +
(YMM_FLOAT_BLOCK + XMM_FLOAT_BLOCK) * sizeof(float)]);
vpaddd(xtmp1, xtmp1, xtmp2);
vpslld(xtmp1, xtmp1, 23);
vmovdqa(ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)], xtmp1);
// load out
vmovdqa(ymm_int, ptr[reg_ptr_tmp]);
}
vmulps(dst, dst, ymm_int);
pop(reg_ptr_global);
}
// compute sigmoid with ymm
void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2,

@ -26,6 +26,7 @@ namespace operators {
namespace math {
namespace jitkernel {
// TODO(TJ): move these to some proper place
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0

Loading…
Cancel
Save