optimize fc jit (#21878)

test=develop
1.6.2
GaoWei8 6 years ago committed by Yiqun Liu
parent 879e3074ea
commit d4dda8628e

@ -61,27 +61,20 @@ class FCFunctor<platform::CPUDeviceContext, T> {
"When bias is NULL, relu can not be true.")); "When bias is NULL, relu can not be true."));
return; return;
} }
if (relu) { auto compute =
auto compute = relu
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache() ? jit::KernelFuncs<jit::VAddReluTuple<T>,
.At(N); platform::CPUPlace>::Cache()
for (int i = 0; i < M; i++) { .At(N)
T* dst = Y + i * N; : jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache()
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst; .At(N);
compute(B, src, dst, N);
}
} else {
auto compute =
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
N);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst; T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N); compute(B, src, dst, N);
}
} }
} }
}; };

Loading…
Cancel
Save