Merge pull request #14321 from tensor-tang/fea/jit/vscal

Fea jitcode vscal vaddbias
revert-14324-fix_vlog
tensor-tang 6 years ago committed by GitHub
commit 22125ebaef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,21 +24,30 @@ namespace gen {
using namespace platform::jit; // NOLINT
bool VVVJitCode::init(int d) {
bool VXXJitCode::init(int d, int scalar_index) {
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return MayIUse(avx);
return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
}
void VVVJitCode::generate() {
void VXXJitCode::generate() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]);
vmovups(ymm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) {
@ -52,8 +61,12 @@ void VVVJitCode::generate() {
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
@ -67,8 +80,12 @@ void VVVJitCode::generate() {
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src1, ptr[param1 + offset]);
vmovq(xmm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
@ -82,8 +99,12 @@ void VVVJitCode::generate() {
rest -= 2;
}
if (rest > 0) {
vmovss(xmm_src1, ptr[param1 + offset]);
vmovss(xmm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
@ -96,6 +117,7 @@ void VVVJitCode::generate() {
}
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math

@ -29,33 +29,46 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
// function: vec = Operand(vec, vec) (maybe with relu)
typedef enum { mul = 0, add } operand_type;
class VVVJitCode : public JitCode {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
const char* name() const override {
std::string base = "VVVJitCode";
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::mul) {
base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
base += (with_relu_ ? "_relu" : "");
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
explicit VVVJitCode(int d, operand_type type, bool with_relu,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {}
static bool init(int d);
static bool init(int d, int scalar_index = 0);
void generate() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
@ -63,13 +76,13 @@ class VVVJitCode : public JitCode {
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(1);
xmm_t xmm_zero = xmm_t(2);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_zero = ymm_t(2);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
} // namespace gen

@ -83,14 +83,15 @@ class VAddReluKernel : public Kernel {
template <typename T>
class VScalKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const T a, T *x) const = 0;
// y = a.*x
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VAddBiasKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
// y = a.+x
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>

File diff suppressed because it is too large Load Diff

@ -409,10 +409,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
}
void Compute(const T* x, T* y) const override {
vscal_->Compute(static_cast<T>(2), x, y);
const T a = static_cast<T>(2), b = static_cast<T>(-1);
vscal_->Compute(&a, x, y, this->num_);
vsigmoid_->Compute(y, y);
vscal_->Compute(static_cast<T>(2), y);
vaddbias_->Compute(static_cast<T>(-1), y, y);
vscal_->Compute(&a, y, y, this->num_);
vaddbias_->Compute(&b, y, y, this->num_);
}
private:
@ -472,10 +473,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(2.f, y); \
vaddbias_->Compute(-1.f, y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
@ -502,10 +504,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
} \
x += this->end_; \
y += this->end_; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(2.f, y); \
vaddbias_->Compute(-1.f, y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#ifdef __AVX__

@ -128,7 +128,7 @@ TEST(JitKernel, vaddbias) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, x_data, ztgt_data);
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
@ -281,10 +281,11 @@ void vtanh_better(
const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
vaddbias,
const int n, const float* x, float* y) {
vscal->Compute(2.f, x, y);
const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n);
vsigmoid->Compute(y, y);
vscal->Compute(2.f, y);
vaddbias->Compute(-1.f, y, y);
vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n);
}
TEST(JitKernel, vtanh) {
@ -531,12 +532,12 @@ TEST(JitKernel, vscal) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, x_data, ztgt_data);
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
auto ttgts1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, y_data);
ker->Compute(&a, y_data, y_data, d);
}
auto ttgte1 = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat

Loading…
Cancel
Save