|
|
|
@ -23,6 +23,8 @@ namespace jit {
|
|
|
|
|
namespace more {
|
|
|
|
|
namespace mix {
|
|
|
|
|
|
|
|
|
|
using CPUPlace = platform::CPUPlace;
|
|
|
|
|
|
|
|
|
|
void VSigmoid(const T* x, T* y, int n) {
|
|
|
|
|
const float min = SIGMOID_THRESHOLD_MIN;
|
|
|
|
|
const float max = SIGMOID_THRESHOLD_MAX;
|
|
|
|
@ -30,7 +32,7 @@ void VSigmoid(const T* x, T* y, int n) {
|
|
|
|
|
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
|
|
|
|
|
y[i] = static_cast<T>(0) - y[i];
|
|
|
|
|
}
|
|
|
|
|
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
compute(y, y, n);
|
|
|
|
|
for (int i = 0; i < n; ++i) {
|
|
|
|
|
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
|
|
|
|
@ -39,9 +41,9 @@ void VSigmoid(const T* x, T* y, int n) {
|
|
|
|
|
|
|
|
|
|
void VTanh(const T* x, T* y, int n) {
|
|
|
|
|
const T a = 2, b = -1;
|
|
|
|
|
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute_scal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_addbias = KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_sigmoid = KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
compute_scal(&a, x, y, n);
|
|
|
|
|
compute_sigmoid(y, y, n);
|
|
|
|
|
compute_scal(&a, y, y, n);
|
|
|
|
@ -49,16 +51,12 @@ void VTanh(const T* x, T* y, int n) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Softmax(const T* x, T* y, int n, int bs) {
|
|
|
|
|
auto compute_hmax =
|
|
|
|
|
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_hsum =
|
|
|
|
|
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vscal =
|
|
|
|
|
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_hmax = KernelFuncs<HMaxTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_hsum = KernelFuncs<HSumTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vscal = KernelFuncs<VScalTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vaddbias =
|
|
|
|
|
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vexp =
|
|
|
|
|
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
KernelFuncs<VAddBiasTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vexp = KernelFuncs<VExpTuple<T>, CPUPlace>::Cache().At(n);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < bs; ++i) {
|
|
|
|
|
T scalar;
|
|
|
|
@ -76,13 +74,13 @@ void Softmax(const T* x, T* y, int n, int bs) {
|
|
|
|
|
|
|
|
|
|
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
|
|
|
|
|
if (type == kVSigmoid) {
|
|
|
|
|
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
return KernelFuncs<VSigmoidTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
} else if (type == kVRelu) {
|
|
|
|
|
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
return KernelFuncs<VReluTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
} else if (type == kVTanh) {
|
|
|
|
|
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
return KernelFuncs<VTanhTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
} else if (type == kVIdentity) {
|
|
|
|
|
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support type: %s", type);
|
|
|
|
|
return nullptr;
|
|
|
|
@ -98,9 +96,9 @@ void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
|
|
|
|
|
const int d = attr->d;
|
|
|
|
|
const int d2 = d * 2;
|
|
|
|
|
const int d3 = d * 3;
|
|
|
|
|
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
|
|
|
|
|
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
auto vadd_d2 = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d2);
|
|
|
|
|
auto act_gate_d = getActFunc(attr->act_gate, d);
|
|
|
|
|
auto act_gate_d2 = getActFunc(attr->act_gate, d2);
|
|
|
|
|
auto act_gate_d3 = getActFunc(attr->act_gate, d3);
|
|
|
|
@ -140,8 +138,8 @@ void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
|
|
|
|
|
int d = attr->d;
|
|
|
|
|
int d2 = d * 2;
|
|
|
|
|
int d3 = d * 3;
|
|
|
|
|
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
auto vadd_d = KernelFuncs<VAddTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
auto act_gate_d = getActFunc(attr->act_gate, d);
|
|
|
|
|
auto act_cand_d = getActFunc(attr->act_cand, d);
|
|
|
|
|
auto act_cell_d = getActFunc(attr->act_cell, d);
|
|
|
|
@ -169,7 +167,7 @@ void GRUH1(gru_t* step, const gru_attr_t* attr) {
|
|
|
|
|
int d2 = d * 2;
|
|
|
|
|
auto act_gate = getActFunc(attr->act_gate, d);
|
|
|
|
|
auto act_cand = getActFunc(attr->act_cand, d);
|
|
|
|
|
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
|
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(d);
|
|
|
|
|
act_gate(gates, gates, d);
|
|
|
|
|
act_cand(gates + d2, gates + d2, d);
|
|
|
|
|
vmul_d(gates, gates + d2, ht, d);
|
|
|
|
@ -182,7 +180,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
|
|
|
|
|
T* ht = reinterpret_cast<T*>(step->ht);
|
|
|
|
|
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
|
|
|
|
|
auto act_gate = getActFunc(attr->act_gate, attr->d);
|
|
|
|
|
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
|
|
|
|
|
auto vmul_d = KernelFuncs<VMulTuple<T>, CPUPlace>::Cache().At(attr->d);
|
|
|
|
|
act_gate(gates + attr->d, gates + attr->d, attr->d);
|
|
|
|
|
vmul_d(ht_1, gates + attr->d, ht, attr->d);
|
|
|
|
|
}
|
|
|
|
@ -230,16 +228,16 @@ bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
|
|
|
|
|
|
|
|
|
|
namespace mix = paddle::operators::jit::more::mix;
|
|
|
|
|
|
|
|
|
|
#define REGISTER_MORE_KERNEL(key, func) \
|
|
|
|
|
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
|
|
|
|
|
|
|
|
|
|
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
|
|
|
|
|
REGISTER_MORE_KERNEL(kVTanh, VTanh);
|
|
|
|
|
REGISTER_MORE_KERNEL(kSoftmax, Softmax);
|
|
|
|
|
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
|
|
|
|
|
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
|
|
|
|
|
REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
|
|
|
|
|
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
|
|
|
|
|
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
|
|
|
|
|
#define REGISTER_MORE_KERNEL(func) \
|
|
|
|
|
REGISTER_JITKERNEL_MORE(k##func, mix, mix::func##Kernel)
|
|
|
|
|
|
|
|
|
|
REGISTER_MORE_KERNEL(VSigmoid);
|
|
|
|
|
REGISTER_MORE_KERNEL(VTanh);
|
|
|
|
|
REGISTER_MORE_KERNEL(Softmax);
|
|
|
|
|
REGISTER_MORE_KERNEL(LSTMCtHt);
|
|
|
|
|
REGISTER_MORE_KERNEL(LSTMC1H1);
|
|
|
|
|
REGISTER_MORE_KERNEL(GRUH1);
|
|
|
|
|
REGISTER_MORE_KERNEL(GRUHtPart1);
|
|
|
|
|
REGISTER_MORE_KERNEL(GRUHtPart2);
|
|
|
|
|
|
|
|
|
|
#undef REGISTER_MORE_KERNEL
|
|
|
|
|