From d962c2a997c86c004e4781238d384c8cc078171c Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 28 Jul 2017 16:22:12 +0800 Subject: [PATCH 1/3] fix bug in CUDADeviceContext --- cmake/flags.cmake | 2 +- paddle/platform/device_context.cc | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 34fd348893..ef31c25203 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # So, don't set these flags here. -LIST(APPEND CUDA_NVCC_FLAGS -std=c++11) +LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) if(CMAKE_BUILD_TYPE STREQUAL "Debug") diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 44afb5d4ee..5218d89d54 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -44,7 +44,19 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); - eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); + // TODO (qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly + // here will cause segment fault. We must implement a class derived from + // Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id + // later. Please refer to the implementation of class EigenCudaStreamDevice + // in TensorFlow. + // + // We find that CUDA 7 introduces a new option, the per-thread default stream, + // that has two effects. Please refer to https://devblogs.nvidia.com/ + // parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/ + // + // So, we decide to use default stream and add --default-stream per-thread nvcc + // flag. Then, two threads with two CUDADeviceContexts will run in parallel. 
+ eigen_stream_.reset(new Eigen::CudaStreamDevice()); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); } From 5364b3944eff4ed9bab22f968b5fb2dc03bd14da Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 28 Jul 2017 17:32:09 +0800 Subject: [PATCH 2/3] use cuda default stream --- cmake/external/eigen.cmake | 11 +---------- paddle/framework/detail/tensor-inl.h | 12 ++++-------- paddle/framework/tensor_test.cc | 18 +++++++++--------- paddle/memory/memcpy.cc | 6 +++--- paddle/memory/memcpy.h | 2 +- paddle/platform/device_context.cc | 9 +-------- paddle/platform/device_context.h | 5 ----- 7 files changed, 19 insertions(+), 44 deletions(-) diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 3e6cedbb0d..f7483f6be9 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3) ExternalProject_Add( extern_eigen3 ${EXTERNAL_PROJECT_LOG_ARGS} - # for latest version, please get from official website - # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" - # URL_MD5 "1a47e78efe365a97de0c022d127607c3" - - # for no-ssl http support, please get from bazel's mirror - # URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz" - # URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7" - - # get from github mirror GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" - GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048" + GIT_TAG "master" PREFIX ${EIGEN_SOURCE_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/detail/tensor-inl.h index 2acae1b0e2..78797f58d2 100644 --- a/paddle/framework/detail/tensor-inl.h +++ b/paddle/framework/detail/tensor-inl.h @@ -83,14 +83,13 @@ inline void Tensor::ShareDataWith(const Tensor& src) { template inline void Tensor::CopyFrom(const Tensor& src, - const platform::CPUDeviceContext& ctx) { + const platform::CPUPlace& dst_place) { src.check_memory_size(); 
Resize(src.dims()); auto src_place = src.holder_->place(); auto src_ptr = static_cast(src.data()); - auto dst_place = ctx.GetPlace(); auto dst_ptr = static_cast(mutable_data(dst_place)); auto size = product(src.dims_) * sizeof(T); @@ -110,26 +109,23 @@ inline void Tensor::CopyFrom(const Tensor& src, #ifndef PADDLE_ONLY_CPU template inline void Tensor::CopyFrom(const Tensor& src, - const platform::CUDADeviceContext& ctx) { + const platform::GPUPlace& dst_place) { src.check_memory_size(); Resize(src.dims()); auto src_place = src.holder_->place(); auto src_ptr = static_cast(src.data()); - auto dst_place = ctx.GetPlace(); auto dst_ptr = static_cast(mutable_data(dst_place)); auto size = product(src.dims_) * sizeof(T); if (platform::is_cpu_place(src_place)) { memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size, - ctx.stream()); + boost::get(src_place), src_ptr, size, 0); } else if (platform::is_gpu_place(src_place)) { memory::Copy(boost::get(dst_place), dst_ptr, - boost::get(src_place), src_ptr, size, - ctx.stream()); + boost::get(src_place), src_ptr, size, 0); } } #endif diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index fd7143cfaa..ef1cc10b84 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -198,8 +198,8 @@ TEST(Tensor, CopyFrom) { int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; memcpy(src_ptr, arr, 9 * sizeof(int)); - auto* cpu_ctx = new paddle::platform::CPUDeviceContext(); - dst_tensor.CopyFrom(src_tensor, *cpu_ctx); + auto cpu_place = new paddle::platform::CPUPlace(); + dst_tensor.CopyFrom(src_tensor, *cpu_place); const int* dst_ptr = dst_tensor.data(); ASSERT_NE(src_ptr, dst_ptr); @@ -208,7 +208,7 @@ TEST(Tensor, CopyFrom) { } Tensor slice_tensor = src_tensor.Slice(1, 2); - dst_tensor.CopyFrom(slice_tensor, *cpu_ctx); + dst_tensor.CopyFrom(slice_tensor, *cpu_place); const int* slice_ptr = slice_tensor.data(); dst_ptr = dst_tensor.data(); ASSERT_NE(dst_ptr, slice_ptr); 
@@ -228,12 +228,12 @@ TEST(Tensor, CopyFrom) { memcpy(src_ptr, arr, 9 * sizeof(int)); // CPU Tensor to GPU Tensor - auto gpu_ctx = new paddle::platform::CUDADeviceContext(0); - gpu_tensor.CopyFrom(src_tensor, *gpu_ctx); + auto gpu_place = new paddle::platform::GPUPlace(0); + gpu_tensor.CopyFrom(src_tensor, *gpu_place); // GPU Tensor to CPU Tensor - auto cpu_ctx = new paddle::platform::CPUDeviceContext(); - dst_tensor.CopyFrom(gpu_tensor, *cpu_ctx); + auto cpu_place = new paddle::platform::CPUPlace(); + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); // Compare Tensors const int* dst_ptr = dst_tensor.data(); @@ -245,10 +245,10 @@ TEST(Tensor, CopyFrom) { Tensor slice_tensor = src_tensor.Slice(1, 2); // CPU Slice Tensor to GPU Tensor - gpu_tensor.CopyFrom(slice_tensor, *gpu_ctx); + gpu_tensor.CopyFrom(slice_tensor, *gpu_place); // GPU Tensor to CPU Tensor - dst_tensor.CopyFrom(gpu_tensor, *cpu_ctx); + dst_tensor.CopyFrom(gpu_tensor, *cpu_place); // Compare Slice Tensors const int* slice_ptr = slice_tensor.data(); diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index aaab1142ca..2cc32dd8dd 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -34,7 +34,7 @@ void Copy(platform::CPUPlace dst_place, void* dst, platform::GPUPlace src_place, const void* src, size_t num, - cudaStream_t stream) { + cudaStream_t stream = 0) { platform::SetDeviceId(src_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); } @@ -44,7 +44,7 @@ void Copy(platform::GPUPlace dst_place, void* dst, platform::CPUPlace src_place, const void* src, size_t num, - cudaStream_t stream) { + cudaStream_t stream = 0) { platform::SetDeviceId(dst_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); } @@ -54,7 +54,7 @@ void Copy(platform::GPUPlace dst_place, void* dst, platform::GPUPlace src_place, const void* src, size_t num, - cudaStream_t stream) { + cudaStream_t stream = 0) { if (dst_place == src_place) { 
platform::SetDeviceId(src_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index 2b9c0eada6..eb2647c617 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); */ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, - cudaStream_t stream); + cudaStream_t stream = 0); #endif // PADDLE_ONLY_CPU diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 5218d89d54..b65c20006c 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -43,7 +43,6 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); - PADDLE_ENFORCE(cudaStreamCreate(&stream_)); // TODO (qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly // here will cause segment fault. 
We must implement a class derived from // Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id @@ -76,15 +75,12 @@ CUDADeviceContext::~CUDADeviceContext() { } eigen_stream_.reset(); eigen_device_.reset(); - PADDLE_ENFORCE(cudaStreamDestroy(stream_)); } Place CUDADeviceContext::GetPlace() const { return place_; } -cudaStream_t CUDADeviceContext::stream() const { return stream_; } - void CUDADeviceContext::Wait() const { - PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE(cudaStreamSynchronize(0)); } Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { @@ -95,7 +91,6 @@ cublasHandle_t CUDADeviceContext::cublas_handle() { if (!cublas_handle_) { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); - PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); } return cublas_handle_; } @@ -104,7 +99,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { if (!cudnn_handle_) { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); - PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); } return cudnn_handle_; } @@ -116,7 +110,6 @@ curandGenerator_t CUDADeviceContext::curand_generator() { CURAND_RNG_PSEUDO_DEFAULT)); PADDLE_ENFORCE( dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } return curand_generator_; } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 57035b335f..2038fafe2e 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -61,9 +61,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Wait for all operations completion in the stream. */ void Wait() const; - /*! \brief Return CUDA stream in the device context. */ - cudaStream_t stream() const; - /*! \brief Return place in the device context. 
*/ Place GetPlace() const override; @@ -91,8 +88,6 @@ class CUDADeviceContext : public DeviceContext { private: uint64_t seed_; - cudaStream_t stream_; - // clang-format off cudnnHandle_t cudnn_handle_ = nullptr; cublasHandle_t cublas_handle_ = nullptr; From 303fb789a550dc1b962af008198158e583918f7d Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 28 Jul 2017 09:47:45 +0000 Subject: [PATCH 3/3] refine tensor copy from --- paddle/framework/detail/tensor-inl.h | 34 ++++++++-------------------- paddle/framework/tensor.h | 9 +------- paddle/memory/memcpy.cc | 6 ++--- paddle/memory/memcpy.h | 2 +- 4 files changed, 15 insertions(+), 36 deletions(-) diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/detail/tensor-inl.h index 78797f58d2..e7ff09dd5c 100644 --- a/paddle/framework/detail/tensor-inl.h +++ b/paddle/framework/detail/tensor-inl.h @@ -83,7 +83,7 @@ inline void Tensor::ShareDataWith(const Tensor& src) { template inline void Tensor::CopyFrom(const Tensor& src, - const platform::CPUPlace& dst_place) { + const platform::Place& dst_place) { src.check_memory_size(); Resize(src.dims()); @@ -94,41 +94,27 @@ inline void Tensor::CopyFrom(const Tensor& src, auto size = product(src.dims_) * sizeof(T); - if (platform::is_cpu_place(src_place)) { + if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size); } #ifndef PADDLE_ONLY_CPU - else if (platform::is_gpu_place(src_place)) { + else if (platform::is_gpu_place(src_place) && + platform::is_cpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size, 0); - } -#endif -} - -#ifndef PADDLE_ONLY_CPU -template -inline void Tensor::CopyFrom(const Tensor& src, - const platform::GPUPlace& dst_place) { - src.check_memory_size(); - Resize(src.dims()); - - auto src_place = src.holder_->place(); - auto src_ptr = static_cast(src.data()); - - auto dst_ptr = 
static_cast(mutable_data(dst_place)); - - auto size = product(src.dims_) * sizeof(T); - - if (platform::is_cpu_place(src_place)) { + } else if (platform::is_cpu_place(src_place) && + platform::is_gpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size, 0); - } else if (platform::is_gpu_place(src_place)) { + } else if (platform::is_gpu_place(src_place) && + platform::is_gpu_place(dst_place)) { memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size, 0); } -} + #endif +} template inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 039ab08374..76070f636b 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -94,14 +94,7 @@ class Tensor { * @note CopyFrom supports CPU <-> GPU, GPU <-> GPU. */ template - inline void CopyFrom(const Tensor& src, - const platform::CPUDeviceContext& ctx); - -#ifndef PADDLE_ONLY_CPU - template - inline void CopyFrom(const Tensor& src, - const platform::CUDADeviceContext& ctx); -#endif + inline void CopyFrom(const Tensor& src, const platform::Place& dst_place); /** * @brief Return the slice of the tensor. 
diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index 2cc32dd8dd..aaab1142ca 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -34,7 +34,7 @@ void Copy(platform::CPUPlace dst_place, void* dst, platform::GPUPlace src_place, const void* src, size_t num, - cudaStream_t stream = 0) { + cudaStream_t stream) { platform::SetDeviceId(src_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); } @@ -44,7 +44,7 @@ void Copy(platform::GPUPlace dst_place, void* dst, platform::CPUPlace src_place, const void* src, size_t num, - cudaStream_t stream = 0) { + cudaStream_t stream) { platform::SetDeviceId(dst_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); } @@ -54,7 +54,7 @@ void Copy(platform::GPUPlace dst_place, void* dst, platform::GPUPlace src_place, const void* src, size_t num, - cudaStream_t stream = 0) { + cudaStream_t stream) { if (dst_place == src_place) { platform::SetDeviceId(src_place.device); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h index eb2647c617..2b9c0eada6 100644 --- a/paddle/memory/memcpy.h +++ b/paddle/memory/memcpy.h @@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); */ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, - cudaStream_t stream = 0); + cudaStream_t stream); #endif // PADDLE_ONLY_CPU