shanyi15-patch-2
Kexin Zhao 7 years ago
parent bfbc25bdb8
commit 8ebfc153dd

@ -28,6 +28,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024;
@ -134,8 +136,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
beta = 0.0f;
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
@ -282,8 +283,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
beta = 0.0f;
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.

@ -24,6 +24,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
using DataLayout = platform::DataLayout;
using PoolingMode = platform::PoolingMode;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
template <typename T>
class PoolCUDNNOpKernel : public framework::OpKernel<T> {
@ -78,9 +80,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle();
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
beta = 0.0f;
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward(
handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
cudnn_output_desc, output_data));
@ -145,9 +145,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle();
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
beta = 0.0f;
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.

Loading…
Cancel
Save