|
|
|
@ -30,15 +30,17 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (relu) {
|
|
|
|
|
auto compute =
|
|
|
|
|
jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
|
|
|
|
|
auto compute = jit::KernelFuncs<jit::kVAddRelu, jit::XYZNTuples<T>,
|
|
|
|
|
platform::CPUPlace>::Cache()
|
|
|
|
|
.At(N);
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
compute(B, dst, dst, N);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto compute =
|
|
|
|
|
jit::Get<jit::kVAdd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
|
|
|
|
|
auto compute = jit::KernelFuncs<jit::kVAdd, jit::XYZNTuples<T>,
|
|
|
|
|
platform::CPUPlace>::Cache()
|
|
|
|
|
.At(N);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
#pragma omp parallel for
|
|
|
|
|
#endif
|
|
|
|
|