Merge pull request #14265 from tensor-tang/fea/jit/vadd

add vadd, vaddrelu jitcode
revert-14324-fix_vlog
tensor-tang 6 years ago committed by GitHub
commit e8642c3c1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -36,7 +36,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
.template Get<jitkernel::VAddReluKernel<T>>(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
vaddrelu->Compute(B, dst, dst);
vaddrelu->Compute(B, dst, dst, N);
}
} else {
const auto& vadd = jitkernel::KernelPool::Instance()
@ -47,7 +47,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
vadd->Compute(B, dst, dst);
vadd->Compute(B, dst, dst, N);
}
}
}

@ -24,19 +24,29 @@ namespace gen {
using namespace platform::jit; // NOLINT
bool VMulJitCode::init(int d) {
bool VVVJitCode::init(int d) {
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return MayIUse(avx);
}
void VMulJitCode::generate() {
void VVVJitCode::generate() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]);
vmovups(ymm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * AVX_FLOAT_BLOCK;
}
@ -44,7 +54,14 @@ void VMulJitCode::generate() {
if (rest >= 4) {
vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovups(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 4;
rest -= 4;
@ -52,7 +69,14 @@ void VMulJitCode::generate() {
if (rest >= 2) {
vmovq(xmm_src1, ptr[param1 + offset]);
vmovq(xmm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddps(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovq(ptr[param3 + offset], xmm_dst);
offset += sizeof(float) * 2;
rest -= 2;
@ -60,12 +84,18 @@ void VMulJitCode::generate() {
if (rest > 0) {
vmovss(xmm_src1, ptr[param1 + offset]);
vmovss(xmm_src2, ptr[param2 + offset]);
if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
vaddss(xmm_dst, xmm_src1, xmm_src2);
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
vmovss(ptr[param3 + offset], xmm_dst);
}
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math

@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
namespace paddle {
namespace operators {
namespace math {
@ -29,28 +29,47 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
class VMulJitCode : public JitCode {
// function: vec = Operand(vec, vec) (maybe with relu)
typedef enum { mul = 0, add } operand_type;
class VVVJitCode : public JitCode {
public:
DECLARE_JIT_CODE(VMulJitCode);
explicit VMulJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {}
const char* name() const override {
std::string base = "VVVJitCode";
if (type_ == operand_type::mul) {
base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
base += (with_relu_ ? "_relu" : "");
return base.c_str();
}
explicit VVVJitCode(int d, operand_type type, bool with_relu,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
with_relu_(with_relu) {}
static bool init(int d);
void generate() override;
private:
int num_;
operand_type type_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_dst = xmm_t(1);
xmm_t xmm_zero = xmm_t(2);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_zero = ymm_t(2);
};
} // namespace gen

@ -71,26 +71,26 @@ class VMulKernel : public Kernel {
template <typename T>
class VAddKernel : public Kernel {
public:
virtual void Compute(const T *x, const T *y, T *z) const = 0;
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VScalKernel : public Kernel {
class VAddReluKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const T a, T *x) const = 0;
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VAddBiasKernel : public Kernel {
class VScalKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const T a, T *x) const = 0;
};
template <typename T>
class VAddReluKernel : public Kernel {
class VAddBiasKernel : public Kernel {
public:
virtual void Compute(const T *x, const T *y, T *z) const = 0;
virtual void Compute(const T a, const T *x, T *y) const = 0;
};
template <typename T>

File diff suppressed because it is too large Load Diff

@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cand_d_->Compute(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_);
@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
/* get fgated and igated*/
vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_);
@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_cell_d_->Compute(ct, gates + d2_);

@ -371,7 +371,7 @@ void lstm_ctht_better(
vtanh_d->Compute(gates, gates);
vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct);
vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */
vtanh_d->Compute(ct, gates + d2);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
@ -695,7 +695,7 @@ TEST(JitKernel, vadd) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data);
ker->Compute(x_data, y_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
@ -723,8 +723,8 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z) {
vadd->Compute(x, y, z);
const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z, d);
vrelu->Compute(z, z);
}
@ -752,12 +752,12 @@ TEST(JitKernel, vaddrelu) {
auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data);
vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d);
}
auto tmkle = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data);
ker->Compute(x_data, y_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat

Loading…
Cancel
Save