|
|
|
@ -82,6 +82,26 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static std::shared_ptr<const VActKernel<T>> GetActKernel(
|
|
|
|
|
const std::string& type, int n) {
|
|
|
|
|
if (type == "sigmoid") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
|
|
|
|
|
} else if (type == "relu") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VReluKernel<T>>(n));
|
|
|
|
|
} else if (type == "tanh") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VTanhKernel<T>>(n));
|
|
|
|
|
} else if (type == "identity" || type == "") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support type: %s", type);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/* LSTM JitKernel */
|
|
|
|
|
template <typename T, jit::cpu_isa_t isa, jit_block>
|
|
|
|
|
class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
@ -93,26 +113,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
d_ = d;
|
|
|
|
|
d2_ = d * 2;
|
|
|
|
|
d3_ = d * 3;
|
|
|
|
|
auto GetActKernel = [&](const std::string& type,
|
|
|
|
|
int n) -> std::shared_ptr<const VActKernel<T>> {
|
|
|
|
|
if (type == "sigmoid") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
|
|
|
|
|
} else if (type == "relu") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VReluKernel<T>>(n));
|
|
|
|
|
} else if (type == "tanh") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VTanhKernel<T>>(n));
|
|
|
|
|
} else if (type == "identity" || type == "") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support type: %s", type);
|
|
|
|
|
};
|
|
|
|
|
act_gate_3d_ = GetActKernel(act_gate, d * 3);
|
|
|
|
|
act_cand_d_ = GetActKernel(act_cand, d);
|
|
|
|
|
act_cell_d_ = GetActKernel(act_cell, d);
|
|
|
|
|
act_gate_d3_ = GetActKernel<T>(act_gate, d3_);
|
|
|
|
|
act_gate_d_ = GetActKernel<T>(act_gate, d);
|
|
|
|
|
act_cand_d_ = GetActKernel<T>(act_cand, d);
|
|
|
|
|
act_cell_d_ = GetActKernel<T>(act_cell, d);
|
|
|
|
|
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
|
|
|
|
|
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
@ -134,10 +138,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht,
|
|
|
|
|
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
|
|
|
|
|
T* checked) const override {
|
|
|
|
|
// gates: W_ch, W_ih, W_fh, W_oh
|
|
|
|
|
act_gate_3d_->Compute(gates + d_, gates + d_);
|
|
|
|
|
act_gate_d3_->Compute(gates + d_, gates + d_);
|
|
|
|
|
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated */
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
@ -149,10 +153,21 @@ class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
}
|
|
|
|
|
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
|
|
|
|
|
/* C_t = igated * cgated*/
|
|
|
|
|
act_gate_d_->Compute(gates + d_, gates + d_);
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int d_, d2_, d3_;
|
|
|
|
|
std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
|
|
|
|
|
std::shared_ptr<const VActKernel<T>> act_gate_d3_, act_gate_d_, act_cand_d_,
|
|
|
|
|
act_cell_d_;
|
|
|
|
|
std::shared_ptr<const VMulKernel<T>> vmul_d_;
|
|
|
|
|
std::shared_ptr<const VAddKernel<T>> vadd_d_;
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
@ -163,8 +178,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
#define INTRI8_FLOAT(isa) \
|
|
|
|
|
template <> \
|
|
|
|
|
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
|
|
|
|
|
float* gates, const float* ct_1, float* ct, float* ht, float* checked) \
|
|
|
|
|
const { \
|
|
|
|
|
float* gates, const float* ct_1, float* ct, float* ht, \
|
|
|
|
|
const float* wp_data, float* checked) const { \
|
|
|
|
|
/* gates: W_ch, W_ih, W_fh, W_oh */ \
|
|
|
|
|
__m256 c, i, f, o; \
|
|
|
|
|
c = _mm256_loadu_ps(gates); \
|
|
|
|
@ -205,51 +220,56 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
|
|
|
|
|
d_ = d;
|
|
|
|
|
d2_ = d * 2;
|
|
|
|
|
d3_ = d * 3;
|
|
|
|
|
auto GetActKernel = [&](const std::string& type,
|
|
|
|
|
int n) -> std::shared_ptr<const VActKernel<T>> {
|
|
|
|
|
if (type == "sigmoid") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VSigmoidKernel<T>>(n));
|
|
|
|
|
} else if (type == "relu") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VReluKernel<T>>(n));
|
|
|
|
|
} else if (type == "tanh") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VTanhKernel<T>>(n));
|
|
|
|
|
} else if (type == "identity" || type == "") {
|
|
|
|
|
return std::dynamic_pointer_cast<const VActKernel<T>>(
|
|
|
|
|
KernelPool::Instance().template Get<VIdentityKernel<T>>(n));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support type: %s", type);
|
|
|
|
|
};
|
|
|
|
|
act_gate_3d_ = GetActKernel(act_gate, d * 3);
|
|
|
|
|
act_cand_d_ = GetActKernel(act_cand, d);
|
|
|
|
|
act_cell_d_ = GetActKernel(act_cell, d);
|
|
|
|
|
act_gate_d_ = GetActKernel<T>(act_gate, d);
|
|
|
|
|
act_cand_d_ = GetActKernel<T>(act_cand, d);
|
|
|
|
|
act_cell_d_ = GetActKernel<T>(act_cell, d);
|
|
|
|
|
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
|
|
|
|
|
vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
|
|
|
|
|
vadd_d2_ = KernelPool::Instance().template Get<VAddKernel<T>>(d2_);
|
|
|
|
|
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht,
|
|
|
|
|
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
|
|
|
|
|
T* checked) const override {
|
|
|
|
|
// gates: W_ch, W_ih, W_fh, W_oh
|
|
|
|
|
act_gate_3d_->Compute(gates + d_, gates + d_);
|
|
|
|
|
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated */
|
|
|
|
|
/* get fgated and igated*/
|
|
|
|
|
vmul_d_->Compute(wp_data, ct_1, checked);
|
|
|
|
|
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_);
|
|
|
|
|
vadd_d2_->Compute(checked, gates + d_, gates + d_);
|
|
|
|
|
act_gate_d2_->Compute(gates + d_, gates + d_);
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, gates + d_);
|
|
|
|
|
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d2_, ct);
|
|
|
|
|
/* get ogated*/
|
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
|
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
|
|
|
|
|
/* C_t = igated * cgated*/
|
|
|
|
|
act_gate_d_->Compute(gates + d_, gates + d_);
|
|
|
|
|
act_cand_d_->Compute(gates, gates);
|
|
|
|
|
vmul_d_->Compute(gates, gates + d_, ct);
|
|
|
|
|
/* get outgated, put W_oc * C_t on igated */
|
|
|
|
|
vmul_d_->Compute(wp_data + d2_, ct, gates + d_);
|
|
|
|
|
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_);
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_gate_d_->Compute(gates + d3_, gates + d3_);
|
|
|
|
|
act_cell_d_->Compute(ct, gates + d2_);
|
|
|
|
|
vmul_d_->Compute(gates + d2_, gates + d3_, ht);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int d_, d2_, d3_;
|
|
|
|
|
std::shared_ptr<const VActKernel<T>> act_gate_3d_, act_cand_d_, act_cell_d_;
|
|
|
|
|
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_cand_d_,
|
|
|
|
|
act_cell_d_;
|
|
|
|
|
std::shared_ptr<const VMulKernel<T>> vmul_d_;
|
|
|
|
|
std::shared_ptr<const VAddKernel<T>> vadd_d_;
|
|
|
|
|
std::shared_ptr<const VAddKernel<T>> vadd_d_, vadd_d2_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
|
|
|
|
|