|
|
|
@ -18,6 +18,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/conv_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/assert.h"
|
|
|
|
|
#include "paddle/fluid/platform/cudnn_helper.h"
|
|
|
|
|
#include "paddle/fluid/platform/float16.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -133,7 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
|
|
|
|
|
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
|
|
|
|
|
// ------------------- cudnn conv forward ---------------------
|
|
|
|
|
T alpha = 1.0f, beta = 0.0f;
|
|
|
|
|
T alpha = static_cast<T>(1.0f);
|
|
|
|
|
T beta = static_cast<T>(0.0f);
|
|
|
|
|
for (int i = 0; i < groups; i++) {
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
|
|
|
|
|
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
|
|
|
|
@ -315,16 +317,18 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_KERNEL(conv2d, CUDNN, ::paddle::platform::CUDAPlace,
|
|
|
|
|
namespace plat = paddle::platform;
|
|
|
|
|
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
|
|
|
|
|
paddle::operators::CUDNNConvOpKernel<float>,
|
|
|
|
|
paddle::operators::CUDNNConvOpKernel<double>);
|
|
|
|
|
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
|
|
|
|
|
paddle::operators::CUDNNConvOpKernel<double>,
|
|
|
|
|
paddle::operators::CUDNNConvOpKernel<plat::float16>);
|
|
|
|
|
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
|
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<float>,
|
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<double>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_KERNEL(conv3d, CUDNN, ::paddle::platform::CUDAPlace,
|
|
|
|
|
REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
|
|
|
|
|
paddle::operators::CUDNNConvOpKernel<float>,
|
|
|
|
|
paddle::operators::CUDNNConvOpKernel<double>);
|
|
|
|
|
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
|
|
|
|
|
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
|
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<float>,
|
|
|
|
|
paddle::operators::CUDNNConvGradOpKernel<double>);
|
|
|
|
|