|
|
|
@ -17,6 +17,7 @@
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/gen/jitcode.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
base += "_Sqrt";
|
|
|
|
|
}
|
|
|
|
|
base += ("_W" + std::to_string(w_));
|
|
|
|
|
// TODO(TJ): make h load from params
|
|
|
|
|
base += ("_H" + std::to_string(h_));
|
|
|
|
|
return base.c_str();
|
|
|
|
|
}
|
|
|
|
|
void genCode() override;
|
|
|
|
@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
protected:
|
|
|
|
|
template <typename JMM>
|
|
|
|
|
void pool_height(int w_offset, int block, int max_num_regs) {
|
|
|
|
|
for (int h = 0; h < h_; ++h) {
|
|
|
|
|
int offset = h * w_ * sizeof(float) + w_offset;
|
|
|
|
|
const int shift_regs = (h == 0) ? 0 : max_num_regs;
|
|
|
|
|
for (int i = 0; i < max_num_regs; ++i) {
|
|
|
|
|
vmovups(JMM(i + shift_regs), ptr[param1 + offset]);
|
|
|
|
|
offset += sizeof(float) * block;
|
|
|
|
|
}
|
|
|
|
|
if (h > 0) {
|
|
|
|
|
// sum anyway
|
|
|
|
|
int offset = w_offset;
|
|
|
|
|
for (int i = 0; i < max_num_regs; ++i) {
|
|
|
|
|
vmovups(JMM(i), ptr[param1 + offset]);
|
|
|
|
|
offset += sizeof(float) * block;
|
|
|
|
|
}
|
|
|
|
|
if (h_ > 1) {
|
|
|
|
|
Label l_next_h;
|
|
|
|
|
mov(reg_h, 1);
|
|
|
|
|
mov(reg_tmp, param1);
|
|
|
|
|
add(reg_tmp, w_ * sizeof(float) + w_offset);
|
|
|
|
|
L(l_next_h);
|
|
|
|
|
{
|
|
|
|
|
mov(reg_ptr_src_i, reg_tmp);
|
|
|
|
|
for (int i = 0; i < max_num_regs; ++i) {
|
|
|
|
|
vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
|
|
|
|
|
// sum anyway
|
|
|
|
|
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
|
|
|
|
|
add(reg_ptr_src_i, sizeof(float) * block);
|
|
|
|
|
}
|
|
|
|
|
inc(reg_h);
|
|
|
|
|
add(reg_tmp, w_ * sizeof(float));
|
|
|
|
|
cmp(reg_h, h_);
|
|
|
|
|
jl(l_next_h, T_NEAR);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// save right now
|
|
|
|
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
|
|
|
|
vbroadcastss(JMM(max_num_regs), reg32_scalar);
|
|
|
|
|
}
|
|
|
|
|
int offset = w_offset;
|
|
|
|
|
offset = w_offset;
|
|
|
|
|
for (int i = 0; i < max_num_regs; ++i) {
|
|
|
|
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
|
|
|
|
vmulps(JMM(i), JMM(i), JMM(max_num_regs));
|
|
|
|
@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) {
|
|
|
|
|
const int rest_used_num_regs = load_rest(rest, w_offset, 0);
|
|
|
|
|
const bool has_block4 = rest / 4 > 0;
|
|
|
|
|
const bool has_block2 = (rest % 4) / 2 > 0;
|
|
|
|
|
const bool has_block1 = (rest % 2) == 1;
|
|
|
|
|
if (h_ > 1) {
|
|
|
|
|
Label l_next_h;
|
|
|
|
|
mov(reg_h, 1);
|
|
|
|
|
mov(reg_tmp, param1);
|
|
|
|
|
add(reg_tmp, w_ * sizeof(float) + w_offset);
|
|
|
|
|
L(l_next_h);
|
|
|
|
|
{
|
|
|
|
|
// int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
|
|
|
|
|
// max_num_regs);
|
|
|
|
|
int reg_idx = 0;
|
|
|
|
|
mov(reg_ptr_src_i, reg_tmp);
|
|
|
|
|
if (has_block4) {
|
|
|
|
|
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
|
|
|
|
|
add(reg_ptr_src_i, sizeof(float) * 4);
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
if (has_block2) {
|
|
|
|
|
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
|
|
|
|
|
add(reg_ptr_src_i, sizeof(float) * 2);
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
if (has_block1) {
|
|
|
|
|
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
|
|
|
|
|
"All heights should use same regs");
|
|
|
|
|
for (int i = 0; i < reg_idx; ++i) {
|
|
|
|
|
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
|
|
|
|
|
}
|
|
|
|
|
inc(reg_h);
|
|
|
|
|
add(reg_tmp, w_ * sizeof(float));
|
|
|
|
|
cmp(reg_h, h_);
|
|
|
|
|
jl(l_next_h, T_NEAR);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// save right now
|
|
|
|
|
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
|
|
|
|
|
vbroadcastss(xmm_t(max_num_regs - 1), reg32_scalar);
|
|
|
|
|
for (int i = 0; i < rest_used_num_regs; ++i) {
|
|
|
|
|
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs - 1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
save_rest(rest, w_offset);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// return the number of used regs, use start from reg 0
|
|
|
|
|
int load_rest(int rest, int w_offset, const int num_shift_regs,
|
|
|
|
|
const int reg_start = 0) {
|
|
|
|
|
const bool has_block4 = rest / 4 > 0;
|
|
|
|
|
const bool has_block2 = (rest % 4) / 2 > 0;
|
|
|
|
|
const bool has_block1 = (rest % 2) == 1;
|
|
|
|
|
int reg_idx = reg_start;
|
|
|
|
|
if (has_block4) {
|
|
|
|
|
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
|
|
|
|
|
w_offset += sizeof(float) * 4;
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
if (has_block2) {
|
|
|
|
|
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
|
|
|
|
|
w_offset += sizeof(float) * 2;
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
if (has_block1) {
|
|
|
|
|
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param1 + w_offset]);
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
return reg_idx;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// use reg start from 0
|
|
|
|
|
void save_rest(int rest, int w_offset, int reg_start = 0) {
|
|
|
|
|
const bool has_block4 = rest / 4 > 0;
|
|
|
|
|
const bool has_block2 = (rest % 4) / 2 > 0;
|
|
|
|
|
const bool has_block1 = (rest % 2) == 1;
|
|
|
|
|
int reg_idx = reg_start;
|
|
|
|
|
if (has_block4) {
|
|
|
|
|
vmovups(ptr[param2 + w_offset], xmm_t(reg_idx));
|
|
|
|
|
w_offset += sizeof(float) * 4;
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
if (has_block2) {
|
|
|
|
|
vmovq(ptr[param2 + w_offset], xmm_t(reg_idx));
|
|
|
|
|
w_offset += sizeof(float) * 2;
|
|
|
|
|
reg_idx++;
|
|
|
|
|
}
|
|
|
|
|
if (has_block1) {
|
|
|
|
|
vmovss(ptr[param2 + w_offset], xmm_t(reg_idx));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int h_;
|
|
|
|
|
int w_;
|
|
|
|
@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode {
|
|
|
|
|
reg64_t param2{abi_param2};
|
|
|
|
|
reg64_t param3{abi_param3};
|
|
|
|
|
reg32_t reg32_scalar{r8d};
|
|
|
|
|
|
|
|
|
|
reg64_t reg_h{r9};
|
|
|
|
|
reg64_t reg_ptr_src_i{r10};
|
|
|
|
|
reg64_t reg_tmp{r11};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace gen
|
|
|
|
|