|
|
|
@ -49,49 +49,16 @@ void VTanh(const T* x, T* y, int n) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Softmax(const T* x, T* y, int n, int bs) {
|
|
|
|
|
typename XRNTuples<T>::func_type compute_hmax{nullptr};
|
|
|
|
|
typename XRNTuples<T>::func_type compute_hsum{nullptr};
|
|
|
|
|
typename AXYNTuples<T>::func_type compute_vscal{nullptr};
|
|
|
|
|
typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
|
|
|
|
|
typename XYNTuples<T>::func_type compute_vexp{nullptr};
|
|
|
|
|
|
|
|
|
|
if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
|
|
|
|
|
compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
|
|
|
|
|
} else {
|
|
|
|
|
compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
|
|
|
|
|
compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
|
|
|
|
|
} else {
|
|
|
|
|
compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
|
|
|
|
|
compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
|
|
|
|
|
compute_vscal);
|
|
|
|
|
} else {
|
|
|
|
|
compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
|
|
|
|
|
compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
|
|
|
|
|
n, compute_vaddbias);
|
|
|
|
|
} else {
|
|
|
|
|
compute_vaddbias =
|
|
|
|
|
KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
|
|
|
|
|
compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
|
|
|
|
|
KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
|
|
|
|
|
} else {
|
|
|
|
|
compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
|
|
|
|
|
}
|
|
|
|
|
auto compute_hmax =
|
|
|
|
|
KernelFuncs<kHMax, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_hsum =
|
|
|
|
|
KernelFuncs<kHSum, XRNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vscal =
|
|
|
|
|
KernelFuncs<kVScal, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vaddbias =
|
|
|
|
|
KernelFuncs<kVAddBias, AXYNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
auto compute_vexp =
|
|
|
|
|
KernelFuncs<kVExp, XYNTuples<T>, platform::CPUPlace>::Cache().At(n);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < bs; ++i) {
|
|
|
|
|
T scalar;
|
|
|
|
|