|
|
|
|
@ -16,7 +16,6 @@
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/gen/act.h" // for ones
|
|
|
|
|
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
size_t code_size = 256 * 1024,
|
|
|
|
|
void* code_ptr = nullptr)
|
|
|
|
|
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
|
|
|
|
|
if (type_ != SeqPoolType::kSum) {
|
|
|
|
|
if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
|
|
|
|
|
type_ == SeqPoolType::kSqrt)) {
|
|
|
|
|
LOG(FATAL) << "Only support sum pool yet ";
|
|
|
|
|
}
|
|
|
|
|
fp_h_[0] = 1.f;
|
|
|
|
|
this->genCode();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
L(l_h_done);
|
|
|
|
|
// save right now
|
|
|
|
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
|
|
|
|
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovups(JMM(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
|
|
|
|
|
movd(JMM(max_num_regs + 1), reg32_fp_h);
|
|
|
|
|
if (type_ == SeqPoolType::kSqrt) {
|
|
|
|
|
vsqrtps(JMM(max_num_regs + 1), JMM(max_num_regs + 1));
|
|
|
|
|
}
|
|
|
|
|
vdivps(JMM(max_num_regs + 2), JMM(max_num_regs), JMM(max_num_regs + 1));
|
|
|
|
|
vbroadcastss(JMM(max_num_regs),
|
|
|
|
|
JMM(max_num_regs + 2)); // TODO(TJ): fix me
|
|
|
|
|
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
|
|
|
|
|
vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
|
|
|
|
|
}
|
|
|
|
|
offset = w_offset;
|
|
|
|
|
for (int i = 0; i < max_num_regs; ++i) {
|
|
|
|
|
@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
L(l_h_done);
|
|
|
|
|
// save right now
|
|
|
|
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
|
|
|
|
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
|
|
|
|
|
vmovups(xmm_t(max_num_regs), ptr[reg_tmp + OFFSET_EXP_ONE]);
|
|
|
|
|
movd(xmm_t(max_num_regs + 1), reg32_fp_h);
|
|
|
|
|
if (type_ == SeqPoolType::kSqrt) {
|
|
|
|
|
vsqrtps(xmm_t(max_num_regs + 1), xmm_t(max_num_regs + 1));
|
|
|
|
|
}
|
|
|
|
|
vdivps(xmm_t(max_num_regs + 2), xmm_t(max_num_regs),
|
|
|
|
|
xmm_t(max_num_regs + 1));
|
|
|
|
|
vbroadcastss(xmm_t(max_num_regs), xmm_t(max_num_regs + 2));
|
|
|
|
|
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
|
|
|
|
|
vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
|
|
|
|
|
for (int i = 0; i < rest_used_num_regs; ++i) {
|
|
|
|
|
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
|
|
|
|
|
}
|
|
|
|
|
@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// One-float slot consumed by the generated code. The constructor stores 1.f
// in fp_h_[0] before calling genCode(); the emitter then loads this array's
// address into reg_tmp (as an immediate obtained via reinterpret_cast<size_t>)
// and vbroadcastss-es the value it points to.
// NOTE(review): because the member's address is embedded in the emitted code,
// the object presumably must stay at the same address for as long as that
// code is used — confirm.
// NOTE(review): ALIGN32_BEG/ALIGN32_END are defined elsewhere; presumably they
// request 32-byte alignment — confirm.
float ALIGN32_BEG fp_h_[1] ALIGN32_END;
|
|
|
|
|
// Copied from attr.w in the constructor's initializer list.
// NOTE(review): presumably the row width, in floats — confirm against callers.
int w_;
|
|
|
|
|
// Copied from attr.type in the constructor's initializer list. The constructor
// logs FATAL when it is none of kSum, kAvg or kSqrt; the code generator
// branches on kAvg/kSqrt to emit the extra scaling step.
SeqPoolType type_;
|
|
|
|
|
// Register alias initialised from abi_param1.
// NOTE(review): by its name, the register carrying the source argument of the
// generated function — confirm.
reg64_t param_src{abi_param1};
|
|
|
|
|
|