|
|
|
@ -23,7 +23,6 @@ namespace jit {
|
|
|
|
|
namespace more {
|
|
|
|
|
namespace mix {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void VSigmoid(const T* x, T* y, int n) {
|
|
|
|
|
const float min = SIGMOID_THRESHOLD_MIN;
|
|
|
|
|
const float max = SIGMOID_THRESHOLD_MAX;
|
|
|
|
@ -38,7 +37,6 @@ void VSigmoid(const T* x, T* y, int n) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void VTanh(const T* x, T* y, int n) {
|
|
|
|
|
const T a = 2, b = -1;
|
|
|
|
|
auto compute_scal = Get<vscal, AXYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
@ -50,26 +48,151 @@ void VTanh(const T* x, T* y, int n) {
|
|
|
|
|
compute_addbias(&b, y, y, n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
bool VSigmoidKernel<float>::UseMe(int d) const {
|
|
|
|
|
return true;
|
|
|
|
|
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
|
|
|
|
|
if (type == vsigmoid) {
|
|
|
|
|
return Get<vsigmoid, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
} else if (type == vrelu) {
|
|
|
|
|
return Get<vrelu, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
} else if (type == vtanh) {
|
|
|
|
|
return Get<vtanh, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
} else if (type == videntity) {
|
|
|
|
|
return Get<videntity, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support type: %s", type);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
bool VTanhKernel<float>::UseMe(int d) const {
|
|
|
|
|
return true;
|
|
|
|
|
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
|
|
|
|
|
T* gates = reinterpret_cast<T*>(step->gates);
|
|
|
|
|
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
|
|
|
|
|
T* ct = reinterpret_cast<T*>(step->ct);
|
|
|
|
|
T* ht = reinterpret_cast<T*>(step->ht);
|
|
|
|
|
const T* wp = reinterpret_cast<const T*>(step->wp);
|
|
|
|
|
T* checked = reinterpret_cast<T*>(step->checked);
|
|
|
|
|
const int d = attr->d;
|
|
|
|
|
const int d2 = d * 2;
|
|
|
|
|
const int d3 = d * 3;
|
|
|
|
|
auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vadd_d = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vadd_d2 = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d2);
|
|
|
|
|
auto act_gate_d = getActFunc(attr->act_gate, d);
|
|
|
|
|
auto act_gate_d2 = getActFunc(attr->act_gate, d2);
|
|
|
|
|
auto act_gate_d3 = getActFunc(attr->act_gate, d2);
|
|
|
|
|
auto act_cand_d = getActFunc(attr->act_cand, d);
|
|
|
|
|
auto act_cell_d = getActFunc(attr->act_cell, d);
|
|
|
|
|
|
|
|
|
|
if (attr->use_peephole) {
|
|
|
|
|
vmul_d(wp, ct_1, checked, d);
|
|
|
|
|
vmul_d(wp + d, ct_1, checked + d, d);
|
|
|
|
|
vadd_d2(checked, gates + d, gates + d, d2);
|
|
|
|
|
act_gate_d2(gates + d, gates + d, d2);
|
|
|
|
|
} else {
|
|
|
|
|
act_gate_d3(gates + d, gates + d, d3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// C_t = C_t-1 * fgated + cand_gated * igated
|
|
|
|
|
act_cand_d(gates, gates, d);
|
|
|
|
|
vmul_d(gates, gates + d, gates + d, d);
|
|
|
|
|
vmul_d(ct_1, gates + d2, gates + d2, d);
|
|
|
|
|
vadd_d(gates + d, gates + d2, ct, d);
|
|
|
|
|
|
|
|
|
|
if (attr->use_peephole) {
|
|
|
|
|
// get ogated
|
|
|
|
|
vmul_d(wp + d2, ct, gates + d, d);
|
|
|
|
|
vadd_d(gates + d, gates + d3, gates + d3, d);
|
|
|
|
|
act_gate_d(gates + d3, gates + d3, d);
|
|
|
|
|
}
|
|
|
|
|
// H_t = act_cell(C_t) * ogated
|
|
|
|
|
act_cell_d(ct, gates + d2, d);
|
|
|
|
|
vmul_d(gates + d2, gates + d3, ht, d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
|
|
|
|
|
template <> \
|
|
|
|
|
bool func##Kernel<double>::UseMe(int d) const { \
|
|
|
|
|
return true; \
|
|
|
|
|
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
|
|
|
|
|
T* gates = reinterpret_cast<T*>(step->gates);
|
|
|
|
|
T* ct = reinterpret_cast<T*>(step->ct);
|
|
|
|
|
T* ht = reinterpret_cast<T*>(step->ht);
|
|
|
|
|
int d = attr->d;
|
|
|
|
|
int d2 = d * 2;
|
|
|
|
|
int d3 = d * 3;
|
|
|
|
|
auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vadd_d = Get<vadd, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto act_gate_d = getActFunc(attr->act_gate, d);
|
|
|
|
|
auto act_cand_d = getActFunc(attr->act_cand, d);
|
|
|
|
|
auto act_cell_d = getActFunc(attr->act_cell, d);
|
|
|
|
|
/* C_t = igated * cgated*/
|
|
|
|
|
act_gate_d(gates + d, gates + d, d);
|
|
|
|
|
act_cand_d(gates, gates, d);
|
|
|
|
|
vmul_d(gates, gates + d, ct, d);
|
|
|
|
|
if (attr->use_peephole) {
|
|
|
|
|
// get outgated, put W_oc * C_t on igated
|
|
|
|
|
const T* wp = reinterpret_cast<const T*>(step->wp);
|
|
|
|
|
vmul_d(wp + d2, ct, gates + d, d);
|
|
|
|
|
vadd_d(gates + d, gates + d3, gates + d3, d);
|
|
|
|
|
}
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
act_gate_d(gates + d3, gates + d3, d);
|
|
|
|
|
act_cell_d(ct, gates + d2, d);
|
|
|
|
|
vmul_d(gates + d2, gates + d3, ht, d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute h1 without h0
|
|
|
|
|
void GRUH1(gru_t* step, const gru_attr_t* attr) {
|
|
|
|
|
T* gates = reinterpret_cast<T*>(step->gates);
|
|
|
|
|
T* ht = reinterpret_cast<T*>(step->ht);
|
|
|
|
|
int d = attr->d;
|
|
|
|
|
int d2 = d * 2;
|
|
|
|
|
auto act_gate = getActFunc(attr->act_gate, d);
|
|
|
|
|
auto act_cand = getActFunc(attr->act_cand, d);
|
|
|
|
|
auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
act_gate(gates, gates, d);
|
|
|
|
|
act_cand(gates + d2, gates + d2, d);
|
|
|
|
|
vmul_d(gates, gates + d2, ht, d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute the first part of GRU: ht = act_gate(r) * ht_1
|
|
|
|
|
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
|
|
|
|
|
// W: {W_update, W_reset; W_state}
|
|
|
|
|
T* gates = reinterpret_cast<T*>(step->gates);
|
|
|
|
|
T* ht = reinterpret_cast<T*>(step->ht);
|
|
|
|
|
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
|
|
|
|
|
auto act_gate = getActFunc(attr->act_gate, attr->d);
|
|
|
|
|
auto vmul_d = Get<vmul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
|
|
|
|
|
act_gate(gates + attr->d, gates + attr->d, attr->d);
|
|
|
|
|
vmul_d(ht_1, gates + attr->d, ht, attr->d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute the second part of GRU:
|
|
|
|
|
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
|
|
|
|
|
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
|
|
|
|
|
T* gates = reinterpret_cast<T*>(step->gates);
|
|
|
|
|
T* ht = reinterpret_cast<T*>(step->ht);
|
|
|
|
|
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
|
|
|
|
|
int d = attr->d;
|
|
|
|
|
auto act_gate = getActFunc(attr->act_gate, d);
|
|
|
|
|
auto act_cand = getActFunc(attr->act_cand, d);
|
|
|
|
|
T* y = gates + d * 2;
|
|
|
|
|
act_gate(gates, gates, d);
|
|
|
|
|
act_cand(y, y, d);
|
|
|
|
|
// out = zt*ht~ + (1-zt)*ht_1
|
|
|
|
|
for (int i = 0; i < d; ++i) {
|
|
|
|
|
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(TJ): tuning me
|
|
|
|
|
bool VSigmoidKernel::UseMe(int d) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool VTanhKernel::UseMe(int d) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool LSTMCtHtKernel::UseMe(lstm_attr_t attr) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool LSTMC1H1Kernel::UseMe(lstm_attr_t attr) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool GRUH1Kernel::UseMe(gru_attr_t attr) const { return true; }
|
|
|
|
|
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
|
|
|
|
|
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
|
|
|
|
|
bool GRUHtPart1Kernel::UseMe(gru_attr_t attr) const { return true; }
|
|
|
|
|
|
|
|
|
|
#undef AWALYS_USE_ME_WITH_DOUBLE
|
|
|
|
|
bool GRUHtPart2Kernel::UseMe(gru_attr_t attr) const { return true; }
|
|
|
|
|
|
|
|
|
|
} // namespace mix
|
|
|
|
|
} // namespace more
|
|
|
|
@ -79,11 +202,15 @@ AWALYS_USE_ME_WITH_DOUBLE(VTanh);
|
|
|
|
|
|
|
|
|
|
namespace mix = paddle::operators::jit::more::mix;
|
|
|
|
|
|
|
|
|
|
#define REGISTER_MORE_KERNEL(key, func) \
|
|
|
|
|
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel<float>, \
|
|
|
|
|
mix::func##Kernel<double>)
|
|
|
|
|
#define REGISTER_MORE_KERNEL(key, func) \
|
|
|
|
|
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
|
|
|
|
|
|
|
|
|
|
REGISTER_MORE_KERNEL(vsigmoid, VSigmoid);
|
|
|
|
|
REGISTER_MORE_KERNEL(vtanh, VTanh);
|
|
|
|
|
REGISTER_MORE_KERNEL(lstmctht, LSTMCtHt);
|
|
|
|
|
REGISTER_MORE_KERNEL(lstmc1h1, LSTMC1H1);
|
|
|
|
|
REGISTER_MORE_KERNEL(gruh1, GRUH1);
|
|
|
|
|
REGISTER_MORE_KERNEL(gruhtpart1, GRUHtPart1);
|
|
|
|
|
REGISTER_MORE_KERNEL(gruhtpart2, GRUHtPart2);
|
|
|
|
|
|
|
|
|
|
#undef REGISTER_MORE_KERNEL
|
|
|
|
|