|
|
|
@ -36,7 +36,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
|
|
|
|
|
.template Get<jitkernel::VAddReluKernel<T>>(N);
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
vaddrelu->Compute(B, dst, dst);
|
|
|
|
|
vaddrelu->Compute(B, dst, dst, N);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
const auto& vadd = jitkernel::KernelPool::Instance()
|
|
|
|
@ -47,7 +47,7 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
|
|
|
|
|
#endif
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
vadd->Compute(B, dst, dst);
|
|
|
|
|
vadd->Compute(B, dst, dst, N);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|