|
|
|
@ -61,27 +61,20 @@ class FCFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
"When bias is NULL, relu can not be true."));
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (relu) {
|
|
|
|
|
auto compute =
|
|
|
|
|
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
|
|
|
|
|
.At(N);
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
|
|
|
|
|
compute(B, src, dst, N);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto compute =
|
|
|
|
|
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
|
|
|
|
|
N);
|
|
|
|
|
auto compute =
|
|
|
|
|
relu
|
|
|
|
|
? jit::KernelFuncs<jit::VAddReluTuple<T>,
|
|
|
|
|
platform::CPUPlace>::Cache()
|
|
|
|
|
.At(N)
|
|
|
|
|
: jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache()
|
|
|
|
|
.At(N);
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
#pragma omp parallel for
|
|
|
|
|
#endif
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
|
|
|
|
|
compute(B, src, dst, N);
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
|
|
|
|
|
compute(B, src, dst, N);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|