From 272f3e6d433c4f2a702e5d181c43920881e3ee25 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 6 Nov 2017 21:30:08 +0800 Subject: [PATCH] refine get cuda context --- paddle/framework/operator.h | 7 +++---- paddle/operators/accuracy_op.cu | 7 ++----- paddle/operators/conv2d_transpose_cudnn_op.cu | 1 - paddle/operators/conv_cudnn_op.cu | 1 - paddle/operators/conv_shift_op.cu | 8 ++------ paddle/operators/cross_entropy_op.cu | 15 +++++--------- paddle/operators/lookup_table_op.cu | 20 ++++++++----------- paddle/operators/multiplex_op.cu | 8 ++------ paddle/operators/nccl_op.cu | 4 +--- 9 files changed, 23 insertions(+), 48 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5c1989c26b..a1303a9098 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -298,11 +298,10 @@ class ExecutionContext { } #ifdef PADDLE_WITH_CUDA - const platform::CUDADeviceContext& cuda_device_context() const { + const inline platform::CUDADeviceContext& cuda_device_context() const { PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); - auto cuda_ctx = - reinterpret_cast(&device_context_); - return *cuda_ctx; + return *reinterpret_cast( + &device_context_); } #endif diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index a0483f367e..d0c4c0d25d 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -72,11 +72,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { } AccuracyCudaKernel<<< - 1, PADDLE_CUDA_NUM_THREADS, 0, - reinterpret_cast( - ctx.device_context()) - .stream()>>>(num_samples, infer_width, indices_data, label_data, - accuracy_data); + 1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>( + num_samples, infer_width, indices_data, label_data, accuracy_data); } }; diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu index 61fcfb3bd8..528e889a54 100644 --- 
a/paddle/operators/conv2d_transpose_cudnn_op.cu +++ b/paddle/operators/conv2d_transpose_cudnn_op.cu @@ -27,7 +27,6 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using DataLayout = platform::DataLayout; -using CUDADeviceContext = platform::CUDADeviceContext; static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024; diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu index e2eb157f40..074a6b1d62 100644 --- a/paddle/operators/conv_cudnn_op.cu +++ b/paddle/operators/conv_cudnn_op.cu @@ -27,7 +27,6 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using DataLayout = platform::DataLayout; -using CUDADeviceContext = platform::CUDADeviceContext; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu index 145e966fe9..74ed1b0ed3 100644 --- a/paddle/operators/conv_shift_op.cu +++ b/paddle/operators/conv_shift_op.cu @@ -130,9 +130,7 @@ class ConvShiftKernel : public framework::OpKernel { dim3 grid_dim(num_x_blocks, batch_size); - auto stream = reinterpret_cast( - context.device_context()) - .stream(); + auto stream = context.cuda_device_context().stream(); conv_shift_forward<<>>( x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); @@ -159,9 +157,7 @@ class ConvShiftGradKernel int y_width = Y->dims()[1]; int y_half_width = (y_width - 1) / 2; - auto stream = reinterpret_cast( - context.device_context()) - .stream(); + auto stream = context.cuda_device_context().stream(); const int x_per_block = 256; int num_x_blocks = div_up(x_width, x_per_block); diff --git 
a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index a523cb6fce..530b319a44 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -82,24 +82,19 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { int block = 512; int grid = (batch_size * class_num + block - 1) / block; + auto stream = ctx.cuda_device_context().stream(); if (ctx.Attr("soft_label")) { auto* label_data = label->data(); - SoftCrossEntropyGradientKernel<<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(dx_data, dy_data, x_data, label_data, - batch_size, class_num); + SoftCrossEntropyGradientKernel<<>>( + dx_data, dy_data, x_data, label_data, batch_size, class_num); } else { math::SetConstant functor; functor(ctx.device_context(), dx, 0); auto* label_data = label->data(); grid = (batch_size + block - 1) / block; - CrossEntropyGradientKernel<<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(dx_data, dy_data, x_data, label_data, - batch_size, class_num); + CrossEntropyGradientKernel<<>>( + dx_data, dy_data, x_data, label_data, batch_size, class_num); } } }; diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index c7ba172066..10d66e5ff4 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -74,10 +74,9 @@ class LookupTableCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTable<<< - grids, threads, 0, reinterpret_cast( - context.device_context()) - .stream()>>>(output, table, ids, N, K, D); + LookupTable<<>>( + output, table, ids, N, K, D); } }; @@ -95,9 +94,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { auto* ids_data = ids->data(); auto ids_dim = ids->dims(); - auto stream = reinterpret_cast( - context.device_context()) - .stream(); + auto stream = context.cuda_device_context().stream(); // copy GPU memory to CPU pinned 
memory framework::Vector new_rows; new_rows.resize(ids_dim[0]); @@ -136,11 +133,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { dim3 threads(128, 8); dim3 grids(8, 1); - LookupTableGrad<<( - context.device_context()) - .stream()>>>(d_table, d_output, ids, N, K, D); + LookupTableGrad< + T, 128, 8, + 8><<>>( + d_table, d_output, ids, N, K, D); } } }; diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index 143a14fef5..7adc7df164 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -35,9 +35,7 @@ class MultiplexGPUKernel : public framework::OpKernel { Tensor index_t_cpu; index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context()); auto* index = index_t_cpu.data(); - auto stream = reinterpret_cast( - ctx.device_context()) - .stream(); + auto stream = ctx.cuda_device_context().stream(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { int32_t k = index[i]; @@ -73,9 +71,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel { index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context()); auto* index = index_t_cpu.data(); - auto stream = reinterpret_cast( - ctx.device_context()) - .stream(); + auto stream = ctx.cuda_device_context().stream(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { size_t k = static_cast(index[i]); diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu index 86dee8ee8e..4f0a2a79ed 100644 --- a/paddle/operators/nccl_op.cu +++ b/paddle/operators/nccl_op.cu @@ -64,9 +64,7 @@ class NCCLAllReduceKernel : public framework::OpKernel { auto* comm = ctx.Input("Communicator"); - auto stream = reinterpret_cast( - ctx.device_context()) - .stream(); + auto stream = ctx.cuda_device_context().stream(); // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId();