|
|
|
@ -48,6 +48,27 @@ void VTanh(const T* x, T* y, int n) {
|
|
|
|
|
compute_addbias(&b, y, y, n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Softmax(const T* x, T* y, int n, int bs) {
|
|
|
|
|
auto compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
auto compute_vexp =
|
|
|
|
|
Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
for (int i = 0; i < bs; ++i) {
|
|
|
|
|
T scalar;
|
|
|
|
|
compute_hmax(x, &scalar, n);
|
|
|
|
|
scalar = static_cast<T>(0) - scalar;
|
|
|
|
|
compute_vaddbias(&scalar, x, y, n); // x - max
|
|
|
|
|
compute_vexp(y, y, n);
|
|
|
|
|
compute_hsum(y, &scalar, n);
|
|
|
|
|
scalar = static_cast<T>(1) / scalar;
|
|
|
|
|
compute_vscal(&scalar, y, y, n);
|
|
|
|
|
x += n;
|
|
|
|
|
y += n;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
|
|
|
|
|
if (type == kVSigmoid) {
|
|
|
|
|
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
|
|
|
|
@ -184,6 +205,8 @@ bool VSigmoidKernel::UseMe(const int& d) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool VTanhKernel::UseMe(const int& d) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool SoftmaxKernel::UseMe(const int& d) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
|
|
|
|
|
|
|
|
|
|
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
|
|
|
|
@ -207,6 +230,7 @@ namespace mix = paddle::operators::jit::more::mix;
|
|
|
|
|
|
|
|
|
|
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
|
|
|
|
|
REGISTER_MORE_KERNEL(kVTanh, VTanh);
|
|
|
|
|
REGISTER_MORE_KERNEL(kSoftmax, Softmax);
|
|
|
|
|
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
|
|
|
|
|
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
|
|
|
|
|
REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
|
|
|
|
|