|
|
|
@ -30,8 +30,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
framework::Tensor Y1;
|
|
|
|
|
T* Y1_data = nullptr;
|
|
|
|
|
auto padding = N % 128 == 0 && K % 128 == 0;
|
|
|
|
|
if (padding) {
|
|
|
|
|
if (padding_weights) {
|
|
|
|
|
const int NN = N + 4;
|
|
|
|
|
const int KK = K + 4;
|
|
|
|
|
framework::Tensor X1;
|
|
|
|
@ -43,25 +42,13 @@ class FCFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
framework::Tensor W1;
|
|
|
|
|
T* W1_data = nullptr;
|
|
|
|
|
if (!padding_weights) {
|
|
|
|
|
W1_data = W1.mutable_data<T>({(K + 4) * (N + 4)}, platform::CPUPlace());
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
#pragma omp parallel for
|
|
|
|
|
#endif
|
|
|
|
|
for (int i = 0; i < K; i++) {
|
|
|
|
|
memcpy(W1_data + i * NN, W + i * N, N * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK,
|
|
|
|
|
(padding_weights ? W : W1_data), NN, static_cast<T>(0.0),
|
|
|
|
|
Y1_data, NN);
|
|
|
|
|
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK, W, NN,
|
|
|
|
|
static_cast<T>(0.0), Y1_data, NN);
|
|
|
|
|
} else {
|
|
|
|
|
blas.MatMul(M, N, K, X, W, Y);
|
|
|
|
|
}
|
|
|
|
|
if (B == NULL) {
|
|
|
|
|
if (padding) {
|
|
|
|
|
if (padding_weights) {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
#pragma omp parallel for
|
|
|
|
|
#endif
|
|
|
|
@ -80,7 +67,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
.At(N);
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
T* src = (padding) ? Y1_data + i * (N + 4) : dst;
|
|
|
|
|
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
|
|
|
|
|
compute(B, src, dst, N);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
@ -92,7 +79,7 @@ class FCFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
#endif
|
|
|
|
|
for (int i = 0; i < M; i++) {
|
|
|
|
|
T* dst = Y + i * N;
|
|
|
|
|
T* src = (padding) ? Y1_data + i * (N + 4) : dst;
|
|
|
|
|
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
|
|
|
|
|
compute(B, src, dst, N);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|