Implement the GPU kernel of fc operator (#19687)
* Refine the codes related to fc op. * Add GPU implementation for fc functor. * Apply fc_fuse_pass in GPU inference. test=develop * Change the cmake for fc op. * Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ. * Add an attribute to set the activation type in fc_op. * Enhance the unittest of fc_op. test=develop * Remove the declaration of FCOpGrad back to the header file. test=develop * Set default value for newly added arguments in test_fc_op. test=developexpand_as_op_1
parent
22301115d0
commit
a65c728e5d
@ -0,0 +1,20 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fc_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
fc, ops::FCOpKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::FCOpKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,62 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/math/fc.h"
|
||||
#include "paddle/fluid/operators/jit/kernels.h"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T>
|
||||
class FCFunctor<platform::CPUDeviceContext, T> {
|
||||
public:
|
||||
void operator()(const platform::CPUDeviceContext& context, const int M,
|
||||
const int N, const int K, const T* X, const T* W, T* Y,
|
||||
const T* B = nullptr, bool relu = false) {
|
||||
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
||||
blas.MatMul(M, N, K, X, W, Y);
|
||||
if (B == NULL) {
|
||||
return;
|
||||
}
|
||||
if (relu) {
|
||||
auto compute =
|
||||
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
|
||||
.At(N);
|
||||
for (int i = 0; i < M; i++) {
|
||||
T* dst = Y + i * N;
|
||||
compute(B, dst, dst, N);
|
||||
}
|
||||
} else {
|
||||
auto compute =
|
||||
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
|
||||
N);
|
||||
#ifdef PADDLE_WITH_MKLML
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int i = 0; i < M; i++) {
|
||||
T* dst = Y + i * N;
|
||||
compute(B, dst, dst, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class FCFunctor<platform::CPUDeviceContext, float>;
|
||||
template class FCFunctor<platform::CPUDeviceContext, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,73 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <algorithm>
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
#include "paddle/fluid/operators/math/fc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename T, bool DoRelu>
|
||||
__global__ void InplaceAddReluKernel(const T* bias, T* data, int M, int N) {
|
||||
for (int i = blockIdx.x; i < M; i += gridDim.x) {
|
||||
int index = i * N + threadIdx.x;
|
||||
for (int j = threadIdx.x; j < N; j += blockDim.x) {
|
||||
T tmp = data[index] + bias[j];
|
||||
if (DoRelu) {
|
||||
data[index] = (tmp > 0) ? tmp : 0;
|
||||
} else {
|
||||
data[index] = tmp;
|
||||
}
|
||||
index += blockDim.x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class FCFunctor<platform::CUDADeviceContext, T> {
|
||||
public:
|
||||
void operator()(const platform::CUDADeviceContext& context, const int M,
|
||||
const int N, const int K, const T* X, const T* W, T* Y,
|
||||
const T* B = nullptr, bool relu = false) {
|
||||
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
|
||||
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
|
||||
static_cast<T>(0.0), Y, N);
|
||||
if (B == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int kThreadsPerBlock = 1024;
|
||||
int max_threads = context.GetMaxPhysicalThreadCount();
|
||||
int num_threads = std::min(kThreadsPerBlock, (((N + 31) >> 5) << 5));
|
||||
int num_blocks = std::max(max_threads / num_threads, 1);
|
||||
if (relu) {
|
||||
InplaceAddReluKernel<
|
||||
T, true><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
|
||||
N);
|
||||
} else {
|
||||
InplaceAddReluKernel<
|
||||
T, false><<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M,
|
||||
N);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template class FCFunctor<platform::CUDADeviceContext, float>;
|
||||
template class FCFunctor<platform::CUDADeviceContext, double>;
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,34 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class FCFunctor {
|
||||
public:
|
||||
void operator()(const DeviceContext& context, const int M, const int N,
|
||||
const int K, const T* X, const T* W, T* Y,
|
||||
const T* B = nullptr, bool relu = false);
|
||||
};
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -1,55 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/operators/jit/kernels.h"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace math {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
|
||||
const int N, const int K, const T* X, const T* W, T* Y,
|
||||
const T* B = NULL, bool relu = false) {
|
||||
blas.MatMul(M, N, K, X, W, Y);
|
||||
if (B == NULL) {
|
||||
return;
|
||||
}
|
||||
if (relu) {
|
||||
auto compute =
|
||||
jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
|
||||
N);
|
||||
for (int i = 0; i < M; i++) {
|
||||
T* dst = Y + i * N;
|
||||
compute(B, dst, dst, N);
|
||||
}
|
||||
} else {
|
||||
auto compute =
|
||||
jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
|
||||
#ifdef PADDLE_WITH_MKLML
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int i = 0; i < M; i++) {
|
||||
T* dst = Y + i * N;
|
||||
compute(B, dst, dst, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace math
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue