From 14d2c3990fe74e063d30f21540019802e1b36194 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 10:41:45 +0800 Subject: [PATCH 01/21] split device_context --- paddle/platform/CMakeLists.txt | 10 +- paddle/platform/cuda_device_context.h | 148 +++++++++++++++++++++++++ paddle/platform/device_context.cc | 13 --- paddle/platform/device_context.h | 131 +--------------------- paddle/platform/device_context_test.cc | 2 +- paddle/platform/dynload/CMakeLists.txt | 2 +- 6 files changed, 158 insertions(+), 148 deletions(-) create mode 100644 paddle/platform/cuda_device_context.h delete mode 100644 paddle/platform/device_context.cc diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 7a198aec6c..4e34e8d02c 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,14 +1,8 @@ add_subdirectory(dynload) -nv_test(cuda_test SRCS cuda_test.cu DEPS dyload_cuda) +nv_test(cuda_test SRCS cuda_test.cu DEPS dynload_cuda) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -IF(WITH_GPU) - set(GPU_CTX_DEPS dyload_cuda dynamic_loader ) -ELSE() - set(GPU_CTX_DEPS) -ENDIF() -cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS}) -nv_test(device_context_test SRCS device_context_test.cc DEPS device_context glog gflags) +nv_test(device_context_test SRCS device_context_test.cc DEPS place eigen3 dynload_cuda) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h new file mode 100644 index 0000000000..e0d79631c5 --- /dev/null +++ b/paddle/platform/cuda_device_context.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/enforce.h" +#include "paddle/platform/cuda.h" +#include "paddle/platform/dynload/cublas.h" +#include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/dynload/curand.h" +#define EIGEN_USE_GPU +#include "paddle/platform/place.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace platform { + +class GPUPlaceGuard { + public: + explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { + if (previous_ != new_place) { + paddle::platform::SetDeviceId(new_place.device); + } + } + + ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } + + private: + GPUPlace previous_; +}; + +class CUDADeviceContext : public DeviceContext { + public: + explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + GPUPlaceGuard guard(gpu_place_); + paddle::platform::throw_on_error(cudaStreamCreate(&stream_), + "cudaStreamCreate failed"); + eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); + eigen_device_ = new Eigen::GpuDevice(eigen_stream_); + } + + void Wait() { + paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), + "cudaStreamSynchronize failed"); + } + + cudaStream_t stream() { return stream_; } + + Eigen::GpuDevice eigen_device() { return *eigen_device_; } + + cublasHandle_t cublas_handle() { + if (!blas_handle_) { + GPUPlaceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == + CUBLAS_STATUS_SUCCESS, + "cublasCreate failed"); + PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( + blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, + "cublasSetStream failed"); + } + return blas_handle_; + } + + cudnnHandle_t cudnn_handle() { + if (!dnn_handle_) { + GPUPlaceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == + CUDNN_STATUS_SUCCESS, + "cudnnCreate failed"); + PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( + dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, + "cudnnSetStream failed"); + } + return dnn_handle_; + } + + curandGenerator_t curand_generator() { + if (!rand_generator_) { + GPUPlaceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( + &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == + CURAND_STATUS_SUCCESS, + "curandCreateGenerator failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( + rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, + "curandSetPseudoRandomGeneratorSeed failed"); + PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( + rand_generator_, stream_) == CURAND_STATUS_SUCCESS, + "curandSetStream failed"); + } + return rand_generator_; + } + + ~CUDADeviceContext() { + Wait(); + if (blas_handle_) { + PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == + CUBLAS_STATUS_SUCCESS, + "cublasDestroy failed"); + } + + if (dnn_handle_) { + PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == + CUDNN_STATUS_SUCCESS, + "cudnnDestroy failed"); + } + + if (rand_generator_) { + PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( + rand_generator_) == CURAND_STATUS_SUCCESS, + "curandDestroyGenerator failed"); + } + + delete eigen_stream_; + delete eigen_device_; + + paddle::platform::throw_on_error(cudaStreamDestroy(stream_), + "cudaStreamDestroy failed"); + } + + private: + GPUPlace gpu_place_; + cudaStream_t stream_; + + Eigen::CudaStreamDevice* eigen_stream_; + Eigen::GpuDevice* eigen_device_; + + cublasHandle_t blas_handle_{nullptr}; + + cudnnHandle_t dnn_handle_{nullptr}; + + int random_seed_; + curandGenerator_t rand_generator_{nullptr}; +}; +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc deleted file mode 100644 index a2dea2ed1e..0000000000 --- a/paddle/platform/device_context.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include - -namespace paddle { -namespace platform { -namespace dynload { -namespace dummy { -// Make DeviceContext A library. -int DUMMY_VAR_FOR_DEV_CTX = 0; - -} // namespace dummy -} // namespace dynload -} // namespace platform -} // namespace paddle \ No newline at end of file diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 160eb4e120..f30c147126 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -13,16 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once - #include "paddle/framework/enforce.h" -#ifndef PADDLE_ONLY_CPU -#include "paddle/platform/cuda.h" -#include "paddle/platform/dynload/cublas.h" -#include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" -#define EIGEN_USE_GPU -#endif -#include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { @@ -33,128 +24,18 @@ class DeviceContext { virtual ~DeviceContext() {} }; -class CPUDeviceContext : public DeviceContext {}; - -#ifndef PADDLE_ONLY_CPU - -class GPUPlaceGuard { +class CPUDeviceContext : public DeviceContext { public: - explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { - if (previous_ != new_place) { - paddle::platform::SetDeviceId(new_place.device); + Eigen::DefaultDevice eigen_handle() { + if (!eigen_handle_) { + eigen_handle_ = new Eigen::DefaultDevice(); } + return *eigen_handle_; } - ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } - private: - GPUPlace previous_; + Eigen::DefaultDevice* eigen_handle_{nullptr}; }; -class CUDADeviceContext : public DeviceContext { - public: - explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { - GPUPlaceGuard guard(gpu_place_); - paddle::platform::throw_on_error(cudaStreamCreate(&stream_), - "cudaStreamCreate failed"); - eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); - eigen_device_ = new Eigen::GpuDevice(eigen_stream_); - } - - void Wait() { - paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), - "cudaStreamSynchronize failed"); - } - - cudaStream_t stream() { return stream_; } - - Eigen::GpuDevice eigen_device() { return *eigen_device_; } - - cublasHandle_t cublas_handle() { - if (!blas_handle_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == - CUBLAS_STATUS_SUCCESS, - "cublasCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( - blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, - "cublasSetStream failed"); - } - return blas_handle_; - } - - cudnnHandle_t cudnn_handle() { - if (!dnn_handle_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == - CUDNN_STATUS_SUCCESS, - "cudnnCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( - dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, - "cudnnSetStream failed"); - } - return dnn_handle_; - } - - curandGenerator_t curand_generator() { - if (!rand_generator_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( - &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == - CURAND_STATUS_SUCCESS, - "curandCreateGenerator failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( - rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, - "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( - rand_generator_, stream_) == CURAND_STATUS_SUCCESS, - "curandSetStream failed"); - } - return rand_generator_; - } - - ~CUDADeviceContext() { - Wait(); - if (blas_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == - CUBLAS_STATUS_SUCCESS, - "cublasDestroy failed"); - } - - if (dnn_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == - CUDNN_STATUS_SUCCESS, - "cudnnDestroy failed"); - } - - if (rand_generator_) { - PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( - rand_generator_) == CURAND_STATUS_SUCCESS, - "curandDestroyGenerator failed"); - } - - delete eigen_stream_; - delete eigen_device_; - - paddle::platform::throw_on_error(cudaStreamDestroy(stream_), - "cudaStreamDestroy failed"); - } - - private: - GPUPlace gpu_place_; - cudaStream_t stream_; - - Eigen::CudaStreamDevice* eigen_stream_; - Eigen::GpuDevice* eigen_device_; - - cublasHandle_t blas_handle_{nullptr}; - - cudnnHandle_t dnn_handle_{nullptr}; - - int random_seed_; - curandGenerator_t rand_generator_{nullptr}; -}; -#endif } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 61be4a307d..cc81e9e789 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/platform/device_context.h" #include "gtest/gtest.h" +#include "paddle/platform/cuda_device_context.h" TEST(CUDADeviceContext, Init) { int count = paddle::platform::GetDeviceCount(); diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index 4a8866b3d3..d205ead845 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) -nv_library(dyload_cuda SRCS cublas.cc cudnn.cc curand.cc) +nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc) From 2dbe60e489221e9883bf08e48efb10cffaabe62b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 12 Jul 2017 11:08:26 +0800 Subject: [PATCH 02/21] Remove Dim::contiguous and Dim::contiguous_strides Paddle's data block is row-major order, while Dim::contiguous and Dim::contiguous_strides are based on column-order. So remove them to prevent misuse. --- paddle/framework/dim.h | 48 ------------------------------------ paddle/framework/dim_test.cu | 28 --------------------- 2 files changed, 76 deletions(-) diff --git a/paddle/framework/dim.h b/paddle/framework/dim.h index bcde291d12..883fdc55eb 100644 --- a/paddle/framework/dim.h +++ b/paddle/framework/dim.h @@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { return ((0 <= idx.head) && (idx.head < size.head)); } -/** - * \brief Check if a size and a stride create a Fortran order contiguous - * block of memory. - */ -template -HOST bool contiguous(const Dim& size, const Dim& stride, int mul = 1) { - if (product(size) == 0) return true; - int contiguous_stride = get<0>(size) == 1 ? 0 : mul; - return (get<0>(stride) == contiguous_stride && - contiguous(size.tail, stride.tail, mul * get<0>(size))); -} - -///\cond HIDDEN -// Base case of contiguous, check the nth stride is the size of -// the prefix multiply of n-1 dims. -template <> -inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) { - if (get<0>(size) == 0) return true; - int contiguous_stride = get<0>(size) == 1 ? 0 : mul; - return get<0>(stride) == contiguous_stride; -} -///\endcond - /** * \brief Compute exclusive prefix-multiply of a Dim. */ @@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { } ///\endcond -/** - * \brief Calculate strides of a contiguous array of the given size - * - * Sets the stride for any dimension with an extent of 1 to 0. - * \param size Dim object containing the size of the array. - * \param base The base stride to use. - * \return Dim object the same size as \p size with the strides. - */ -template -HOSTDEVICE Dim contiguous_strides(const Dim& size, int base = 1) { - int stride = size.head == 1 ? 0 : base; - return Dim(stride, contiguous_strides(size.tail, base * size.head)); -} - -///\cond HIDDEN - -// Base case of contiguous_strides -template <> -HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) { - int stride = size.head == 1 ? 0 : base; - return Dim<1>(stride); -} - -///\endcond - /** * Add two dimensions together */ diff --git a/paddle/framework/dim_test.cu b/paddle/framework/dim_test.cu index 809bf04826..0521741519 100644 --- a/paddle/framework/dim_test.cu +++ b/paddle/framework/dim_test.cu @@ -58,24 +58,6 @@ TEST(Dim, Equality) { EXPECT_EQ(paddle::framework::get<1>(c), 3); EXPECT_EQ(paddle::framework::get<2>(c), 12); - // contiguous_strides - c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10)); - EXPECT_EQ(paddle::framework::get<0>(c), 1); - EXPECT_EQ(paddle::framework::get<1>(c), 0); - EXPECT_EQ(paddle::framework::get<2>(c), 10); - c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1)); - EXPECT_EQ(paddle::framework::get<0>(c), 1); - EXPECT_EQ(paddle::framework::get<1>(c), 10); - EXPECT_EQ(paddle::framework::get<2>(c), 0); - c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10)); - EXPECT_EQ(paddle::framework::get<0>(c), 0); - EXPECT_EQ(paddle::framework::get<1>(c), 1); - EXPECT_EQ(paddle::framework::get<2>(c), 10); - c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4)); - EXPECT_EQ(paddle::framework::get<0>(c), 1); - EXPECT_EQ(paddle::framework::get<1>(c), 2); - EXPECT_EQ(paddle::framework::get<2>(c), 6); - // generate from an index auto size = paddle::framework::make_dim(4, 5, 2); c = paddle::framework::Dim<3>(14, size); @@ -101,16 +83,6 @@ TEST(Dim, Bool) { EXPECT_TRUE(a == a); EXPECT_FALSE(a == b); EXPECT_TRUE(a == c); - - // contiguous check - int x = 4, y = 5, z = 2; - paddle::framework::Dim<3> sizef(x, y, z); - paddle::framework::Dim<3> stridea(1, x, x*y); - paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y); - paddle::framework::Dim<3> stridec(1, x, 2*x*y); - EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea)); - EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb)); - EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec)); } TEST(Dim, Print) { From 8f5a9fd9a7297007dc259114c23d986c8ee4e06a Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 11:20:14 +0800 Subject: [PATCH 03/21] fix gpu build error --- paddle/platform/CMakeLists.txt | 2 +- paddle/platform/cuda_device_context.h | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 4e34e8d02c..5e2f203555 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -5,4 +5,4 @@ nv_test(cuda_test SRCS cuda_test.cu DEPS dynload_cuda) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -nv_test(device_context_test SRCS device_context_test.cc DEPS place eigen3 dynload_cuda) +nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h index e0d79631c5..0ba1f802a6 100644 --- a/paddle/platform/cuda_device_context.h +++ b/paddle/platform/cuda_device_context.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU #include "paddle/platform/place.h" +#include "paddle/platform/device_context.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { From b5a8d5b4b46c36dadc1a17b66c72984930e76305 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 11:25:32 +0800 Subject: [PATCH 04/21] remove unused deps --- paddle/platform/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 5e2f203555..e93592cc4c 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,6 +1,6 @@ add_subdirectory(dynload) -nv_test(cuda_test SRCS cuda_test.cu DEPS dynload_cuda) +nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) From 6bbc2944aec25e11921bb98c93440a4e29bc3967 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 12:10:27 +0800 Subject: [PATCH 05/21] fix code style --- paddle/platform/cuda_device_context.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h index 0ba1f802a6..69415fe615 100644 --- a/paddle/platform/cuda_device_context.h +++ b/paddle/platform/cuda_device_context.h @@ -20,8 +20,8 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU -#include "paddle/platform/place.h" #include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { From ef5f9debc61ce4f6b3142fedbf85a118a34731eb Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 13:51:04 +0800 Subject: [PATCH 06/21] refine device_context --- paddle/platform/CMakeLists.txt | 1 + .../{cuda_device_context.h => cuda_device.h} | 13 +++--- paddle/platform/cuda_device_test.cc | 33 +++++++++++++++ paddle/platform/device.h | 41 +++++++++++++++++++ paddle/platform/device_context.h | 23 ++++------- paddle/platform/device_context_test.cc | 23 +++++------ 6 files changed, 102 insertions(+), 32 deletions(-) rename paddle/platform/{cuda_device_context.h => cuda_device.h} (94%) create mode 100644 paddle/platform/cuda_device_test.cc create mode 100644 paddle/platform/device.h diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index e93592cc4c..d40e49b546 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -5,4 +5,5 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) +nv_test(cuda_device_test SRCS cuda_device_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device.h similarity index 94% rename from paddle/platform/cuda_device_context.h rename to paddle/platform/cuda_device.h index 69415fe615..cbb69d1cc5 100644 --- a/paddle/platform/cuda_device_context.h +++ b/paddle/platform/cuda_device.h @@ -20,10 +20,12 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU -#include "paddle/platform/device_context.h" +#include "paddle/platform/device.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" +using DEVICE_GPU = Eigen::GpuDevice; + namespace paddle { namespace platform { @@ -41,9 +43,10 @@ class GPUPlaceGuard { GPUPlace previous_; }; -class CUDADeviceContext : public DeviceContext { +template <> +class Device { public: - explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); paddle::platform::throw_on_error(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); @@ -58,7 +61,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream() { return stream_; } - Eigen::GpuDevice eigen_device() { return *eigen_device_; } + DEVICE_GPU eigen_device() { return *eigen_device_; } cublasHandle_t cublas_handle() { if (!blas_handle_) { @@ -136,7 +139,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream_; Eigen::CudaStreamDevice* eigen_stream_; - Eigen::GpuDevice* eigen_device_; + DEVICE_GPU* eigen_device_; cublasHandle_t blas_handle_{nullptr}; diff --git a/paddle/platform/cuda_device_test.cc b/paddle/platform/cuda_device_test.cc new file mode 100644 index 0000000000..ea647be876 --- /dev/null +++ b/paddle/platform/cuda_device_test.cc @@ -0,0 +1,33 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/cuda_device.h" +#include "gtest/gtest.h" + +TEST(Device, Init) { + int count = paddle::platform::GetDeviceCount(); + for (int i = 0; i < count; i++) { + paddle::platform::Device* device = + new paddle::platform::Device(i); + Eigen::GpuDevice gpu_device = device->eigen_device(); + ASSERT_NE(nullptr, gpu_device.stream()); + cudnnHandle_t cudnn_handle = device->cudnn_handle(); + ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); + curandGenerator_t curand_handle = device->curand_generator(); + ASSERT_NE(nullptr, curand_handle); + delete device; + } +} diff --git a/paddle/platform/device.h b/paddle/platform/device.h new file mode 100644 index 0000000000..9ae41cbcb0 --- /dev/null +++ b/paddle/platform/device.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "unsupported/Eigen/CXX11/Tensor" + +using DEVICE_CPU = Eigen::DefaultDevice; + +namespace paddle { +namespace platform { + +template +class Device; + +template <> +class Device { + public: + DEVICE_CPU eigen_handle() { + if (!eigen_handle_) { + eigen_handle_ = new Eigen::DefaultDevice(); + } + return *eigen_handle_; + } + + private: + DEVICE_CPU* eigen_handle_{nullptr}; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f30c147126..8b0bac6280 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -14,27 +14,22 @@ limitations under the License. */ #pragma once #include "paddle/framework/enforce.h" +#include "paddle/platform/device.h" #include "unsupported/Eigen/CXX11/Tensor" +#ifndef PADDLE_ONLY_CPU +#include "paddle/platform/cuda_device.h" +#endif namespace paddle { namespace platform { -class DeviceContext { - public: - virtual ~DeviceContext() {} -}; +struct DeviceContext { + void* device_context{nullptr}; -class CPUDeviceContext : public DeviceContext { - public: - Eigen::DefaultDevice eigen_handle() { - if (!eigen_handle_) { - eigen_handle_ = new Eigen::DefaultDevice(); - } - return *eigen_handle_; + template + inline paddle::platform::Device* device_context() { + return static_cast*>(device_context); } - - private: - Eigen::DefaultDevice* eigen_handle_{nullptr}; }; } // namespace platform diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index cc81e9e789..ab8a6d8195 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -12,22 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/platform/device_context.h" #include "gtest/gtest.h" -#include "paddle/platform/cuda_device_context.h" -TEST(CUDADeviceContext, Init) { +TEST(DeviceContext, Init) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { - paddle::platform::CUDADeviceContext* device_context = - new paddle::platform::CUDADeviceContext(i); - Eigen::GpuDevice gpu_device = device_context->eigen_device(); + paddle::platform::Device* device = + new paddle::platform::Device(i); + paddle::platform::DeviceContext context; + context.device_context = device; + Eigen::GpuDevice gpu_device = + context.device_context->eigen_device(); ASSERT_NE(nullptr, gpu_device.stream()); - cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); - ASSERT_NE(nullptr, cudnn_handle); - cublasHandle_t cublas_handle = device_context->cublas_handle(); - ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device_context->curand_generator(); - ASSERT_NE(nullptr, curand_handle); - delete device_context; + delete device; } -} +} \ No newline at end of file From 0ff819207230ac345efefc0a37a3883e81d43c02 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 12 Jul 2017 14:02:57 +0800 Subject: [PATCH 07/21] Add OperatorWithKernel class * User can register OpKernel to its Ops. The OpKernelMap saved in OperatorWithKernel. Each Op which inherits OperatorWithKernel will use `OpKernel::Compute` instead of Run. --- paddle/CMakeLists.txt | 1 - paddle/framework/op_registry_test.cc | 33 ++++---- paddle/framework/operator.cc | 8 -- paddle/framework/operator.h | 117 ++++++++++++++++++--------- paddle/framework/operator_test.cc | 39 ++++----- paddle/operators/.clang-format | 5 -- paddle/operators/CMakeLists.txt | 0 paddle/operators/demo_op.h | 59 -------------- paddle/platform/device_context.h | 18 ++++- 9 files changed, 127 insertions(+), 153 deletions(-) delete mode 100644 paddle/operators/.clang-format delete mode 100644 paddle/operators/CMakeLists.txt delete mode 100644 paddle/operators/demo_op.h diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 2c1eb7521d..58a35564f8 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,7 +15,6 @@ if(Boost_FOUND) add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) - add_subdirectory(operators) add_subdirectory(pybind) endif() diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index c4baafc2ae..f5d45a80bb 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -1,17 +1,15 @@ #include "paddle/framework/op_registry.h" #include -#include "paddle/framework/operator.h" -#include "paddle/operators/demo_op.h" using namespace paddle::framework; namespace paddle { namespace framework { -class CosineOp : public OperatorWithKernel { +class CosineOp : public OperatorBase { public: - void Run(const OpRunContext* context) const override { - printf("%s\n", DebugString().c_str()); - } + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override {} + void InferShape(const std::shared_ptr& scope) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) -class MyTestOp : public OperatorWithKernel { +class MyTestOp : public OperatorBase { + public: + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override {} + public: - void Run(const OpRunContext* ctx) const override { - printf("%s\n", DebugString().c_str()); - printf("test_attr = %d\n", ctx->op_->GetAttr("test_attr")); - } }; class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); - auto dev_ctx = DeviceContext(); - op->Run(scope, &dev_ctx); + paddle::platform::CPUDeviceContext dev_ctx; + op->Run(scope, dev_ctx); float scale_get = op->GetAttr("scale"); ASSERT_EQ(scale_get, scale); } @@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) { paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); - auto dev_ctx = DeviceContext(); - op->Run(scope, &dev_ctx); + paddle::platform::CPUDeviceContext dev_ctx; + op->Run(scope, dev_ctx); ASSERT_EQ(op->GetAttr("scale"), 1.0); } @@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(4); paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - auto dev_ctx = DeviceContext(); + paddle::platform::CPUDeviceContext dev_ctx; auto scope = std::make_shared(); - op->Run(scope, &dev_ctx); + op->Run(scope, dev_ctx); int test_attr = op->GetAttr("test_attr"); ASSERT_EQ(test_attr, 4); } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 3db3706e47..8f7adff8b3 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const { return ss.str(); } -const Variable* OpRunContext::Input(int index) const { - return scope_->GetVariable(op_->inputs_[index]); -} - -Variable* OpRunContext::Output(int index) const { - return scope_->GetVariable(op_->outputs_[index]); -} - } // namespace framework } // namespace paddle \ No newline at end of file diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6570d58698..0ce422e007 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,44 +14,22 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include +#include +#include #include #include #include #include -#include "paddle/framework/attr_checker.h" -#include "paddle/framework/op_desc.pb.h" -#include "paddle/framework/scope.h" -#include "paddle/utils/Error.h" - namespace paddle { namespace framework { class OperatorBase; -class DeviceContext {}; - -/** - * OpRunContext is the only parameter of Operator's Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * OpRunContext. User should construct it before run the Operator. - */ -class OpRunContext { - public: - OpRunContext(const OperatorBase* op, const std::shared_ptr scope, - const DeviceContext* device_context) - : op_(op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const; - Variable* Output(int index) const; - - public: - const OperatorBase* op_; - const std::shared_ptr scope_; - const DeviceContext* device_context_; -}; - /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -77,7 +55,10 @@ class OperatorBase { /// Net will call this function to Run an op. virtual void Run(const std::shared_ptr& scope, - const DeviceContext* dev_ctx) const = 0; + const platform::DeviceContext& dev_ctx) const = 0; + + protected: + std::string Type() const { return desc_.type(); } public: OpDesc desc_; @@ -86,22 +67,84 @@ class OperatorBase { AttributeMap attrs_; }; +class OpKernel { + public: + /** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ + class KernelContext { + public: + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} + + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } + + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } + + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; + }; + + virtual void Compute(const KernelContext& context) const = 0; + + virtual ~OpKernel() {} +}; + class OperatorWithKernel : public OperatorBase { public: - virtual ~OperatorWithKernel() {} + struct OpKernelKey { + platform::Place place_; - virtual void InferShape(const std::shared_ptr& scope) const {} + OpKernelKey() = default; + OpKernelKey(const platform::DeviceContext& dev_ctx) { + place_ = dev_ctx.GetPlace(); + } + + bool operator==(const OpKernelKey& o) const { return place_ == o.place_; } + }; + + struct OpKernelHash { + std::hash hash_; + size_t operator()(const OpKernelKey& key) const { + return hash_(platform::is_gpu_place(key.place_)); + } + }; + + using OpKernelMap = + std::unordered_map, OpKernelHash>; void Run(const std::shared_ptr& scope, - const DeviceContext* dev_ctx) const { - OpRunContext op_ctx(this, scope, dev_ctx); - Run(&op_ctx); + const platform::DeviceContext& dev_ctx) const final { + auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); + opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); } - /// when implement an Op, your should implement this function. - /// this function should be moved to OpKernel later - virtual void Run(const OpRunContext* context) const = 0; + static std::unordered_map& + AllOpKernels() { + static std::unordered_map g_all_op_kernels; + return g_all_op_kernels; + }; }; } // namespace framework } // namespace paddle + +#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \ + struct __op_kernel_register__##type##__ { \ + __op_kernel_register__##type##__() { \ + ::paddle::framework::OperatorWithKernel::OpKernelKey key; \ + key.place_ = PlaceType(); \ + ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ + .reset(new KernelType()); \ + } \ + }; \ + static __op_kernel_register__##type##__ __reg_kernel_##type##__ diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 48808dabb2..86f45f108a 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -19,17 +19,15 @@ limitations under the License. */ namespace paddle { namespace framework { -class OperatorTest : public OperatorWithKernel { +class OperatorTest : public OperatorBase { public: - void Run(const OpRunContext* ctx) const override { - float scale = ctx->op_->GetAttr("scale"); - PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized"); - PADDLE_ENFORCE(ctx->Output(0) == nullptr, - "Output(1) should not initialized"); - auto output1 = ctx->scope_->CreateVariable("output1"); - PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope"); - printf("get attr %s = %f\n", "scale", scale); - printf("%s\n", DebugString().c_str()); + void InferShape(const std::shared_ptr& scope) const override {} + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override { + float scale = GetAttr("scale"); + ASSERT_NEAR(scale, 3.14, 1e-5); + ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); + ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); } }; @@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) -TEST(OperatorBase, DebugString) { +TEST(OperatorBase, all) { OpDesc op_desc; op_desc.set_type("test_operator"); - std::vector inputs = {"IN1", "IN2"}; - for (auto& input : inputs) { - op_desc.add_inputs(input); - } - std::vector outputs = {"OUT1", "OUT2"}; - for (auto& output : outputs) { - op_desc.add_outputs(output); - } + *op_desc.mutable_inputs()->Add() = "IN1"; + *op_desc.mutable_outputs()->Add() = "OUT1"; auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); float scale = 3.14; attr->set_f(scale); - DeviceContext device_context; + platform::CPUDeviceContext device_context; auto scope = std::make_shared(); OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); - ASSERT_EQ(op->inputs_, inputs); - ASSERT_EQ(op->outputs_, outputs); ASSERT_EQ(op->GetAttr("scale"), scale); - op->Run(scope, &device_context); + scope->CreateVariable("OUT1"); + op->Run(scope, device_context); + std::cout << op->DebugString() << std::endl; + delete op; } } // namespace framework diff --git a/paddle/operators/.clang-format b/paddle/operators/.clang-format deleted file mode 100644 index 29282dc87e..0000000000 --- a/paddle/operators/.clang-format +++ /dev/null @@ -1,5 +0,0 @@ ---- -Language: Cpp -BasedOnStyle: Google -Standard: Cpp11 -... diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/paddle/operators/demo_op.h b/paddle/operators/demo_op.h deleted file mode 100644 index d0b7420b4e..0000000000 --- a/paddle/operators/demo_op.h +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include "paddle/framework/op_registry.h" - -using namespace paddle::framework; - -namespace paddle { -namespace operators { - -class CosineOp : public OperatorWithKernel { - public: - void Run(const OpRunContext *context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); - AddType("cos"); - AddComment("This is cos op"); - } -}; - -REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) - -class MyTestOp : public OperatorWithKernel { - public: - void Run(const OpRunContext *context) const override { - printf("%s\n", DebugString().c_str()); - } -}; - -class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { - public: - MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); - auto my_checker = [](int i) { - PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); - }; - AddAttr("test_attr", "a simple test attribute") - .AddCustomChecker(my_checker); - AddType("my_test_op"); - AddComment("This is my_test op"); - } -}; - -REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) - -} // namespace operators -} // namespace operators diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 160eb4e120..e3c2cd2647 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -22,8 +22,8 @@ limitations under the License. */ #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU #endif -#include "paddle/platform/place.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include +#include namespace paddle { namespace platform { @@ -31,9 +31,16 @@ namespace platform { class DeviceContext { public: virtual ~DeviceContext() {} + virtual Place GetPlace() const = 0; }; -class CPUDeviceContext : public DeviceContext {}; +class CPUDeviceContext : public DeviceContext { + public: + Place GetPlace() const override { + Place retv = CPUPlace(); + return retv; + } +}; #ifndef PADDLE_ONLY_CPU @@ -61,6 +68,11 @@ class CUDADeviceContext : public DeviceContext { eigen_device_ = new Eigen::GpuDevice(eigen_stream_); } + Place GetPlace() const override { + Place retv = GPUPlace(); + return retv; + } + void Wait() { paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), "cudaStreamSynchronize failed"); From 4d336d9063451a7568863b249ac53fe7de8bbaa8 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 15:03:44 +0800 Subject: [PATCH 08/21] follow comments --- .../{cuda_device.h => cuda_device_context.h} | 15 ++++--- paddle/platform/cuda_device_test.cc | 33 --------------- paddle/platform/device.h | 41 ------------------- paddle/platform/device_context.h | 34 +++++++++++---- paddle/platform/device_context_test.cc | 33 +++++++++++---- 5 files changed, 59 insertions(+), 97 deletions(-) rename paddle/platform/{cuda_device.h => cuda_device_context.h} (94%) delete mode 100644 paddle/platform/cuda_device_test.cc delete mode 100644 paddle/platform/device.h diff --git a/paddle/platform/cuda_device.h b/paddle/platform/cuda_device_context.h similarity index 94% rename from paddle/platform/cuda_device.h rename to paddle/platform/cuda_device_context.h index cbb69d1cc5..420159fb2c 100644 --- a/paddle/platform/cuda_device.h +++ b/paddle/platform/cuda_device_context.h @@ -20,7 +20,6 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU -#include "paddle/platform/device.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -29,6 +28,13 @@ using DEVICE_GPU = Eigen::GpuDevice; namespace paddle { namespace platform { +class CUDADeviceContext; + +template <> +DEVICE_GPU DeviceContext::get_eigen_device() { + return static_cast(this)->eigen_handle(); +} + class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { @@ -43,8 +49,7 @@ class GPUPlaceGuard { GPUPlace previous_; }; -template <> -class Device { +class CUDADeviceContext : public DeviceContext { public: explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); @@ -61,7 +66,7 @@ class Device { cudaStream_t stream() { return stream_; } - DEVICE_GPU eigen_device() { return *eigen_device_; } + Eigen::GpuDevice eigen_device() { return *eigen_device_; } cublasHandle_t cublas_handle() { if (!blas_handle_) { @@ -139,7 +144,7 @@ class Device { cudaStream_t stream_; Eigen::CudaStreamDevice* eigen_stream_; - DEVICE_GPU* eigen_device_; + Eigen::GpuDevice* eigen_device_; cublasHandle_t blas_handle_{nullptr}; diff --git a/paddle/platform/cuda_device_test.cc b/paddle/platform/cuda_device_test.cc deleted file mode 100644 index ea647be876..0000000000 --- a/paddle/platform/cuda_device_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/platform/cuda_device.h" -#include "gtest/gtest.h" - -TEST(Device, Init) { - int count = paddle::platform::GetDeviceCount(); - for (int i = 0; i < count; i++) { - paddle::platform::Device* device = - new paddle::platform::Device(i); - Eigen::GpuDevice gpu_device = device->eigen_device(); - ASSERT_NE(nullptr, gpu_device.stream()); - cudnnHandle_t cudnn_handle = device->cudnn_handle(); - ASSERT_NE(nullptr, cudnn_handle); - cublasHandle_t cublas_handle = device->cublas_handle(); - ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device->curand_generator(); - ASSERT_NE(nullptr, curand_handle); - delete device; - } -} diff --git a/paddle/platform/device.h b/paddle/platform/device.h deleted file mode 100644 index 9ae41cbcb0..0000000000 --- a/paddle/platform/device.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "unsupported/Eigen/CXX11/Tensor" - -using DEVICE_CPU = Eigen::DefaultDevice; - -namespace paddle { -namespace platform { - -template -class Device; - -template <> -class Device { - public: - DEVICE_CPU eigen_handle() { - if (!eigen_handle_) { - eigen_handle_ = new Eigen::DefaultDevice(); - } - return *eigen_handle_; - } - - private: - DEVICE_CPU* eigen_handle_{nullptr}; -}; - -} // namespace platform -} // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 8b0bac6280..11a05702cd 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -13,23 +13,39 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/enforce.h" -#include "paddle/platform/device.h" #include "unsupported/Eigen/CXX11/Tensor" -#ifndef PADDLE_ONLY_CPU -#include "paddle/platform/cuda_device.h" -#endif + +using DEVICE_CPU = Eigen::DefaultDevice; namespace paddle { namespace platform { -struct DeviceContext { - void* device_context{nullptr}; +class CPUDeviceContext; + +class DeviceContext { + public: + virtual ~DeviceContext() {} template - inline paddle::platform::Device* device_context() { - return static_cast*>(device_context); + DeviceType get_eigen_device(); +}; + +template <> +DEVICE_CPU DeviceContext::get_eigen_device() { + return static_cast(this)->eigen_handle(); +} + +class CPUDeviceContext : public DeviceContext { + public: + Eigen::DefaultDevice eigen_handle() { + if (!eigen_handle_) { + eigen_handle_ = new Eigen::DefaultDevice(); + } + return *eigen_handle_; } + + private: + Eigen::DefaultDevice* eigen_handle_{nullptr}; }; } // namespace platform diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index ab8a6d8195..8390e97b15 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -12,19 +12,34 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/platform/device_context.h" #include "gtest/gtest.h" +#include "paddle/platform/cuda_device.h" -TEST(DeviceContext, Init) { +TEST(Device, Init) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { - paddle::platform::Device* device = - new paddle::platform::Device(i); - paddle::platform::DeviceContext context; - context.device_context = device; + paddle::platform::DeviceContext* device_context = + new paddle::platform::CUDADeviceContext(i); Eigen::GpuDevice gpu_device = - context.device_context->eigen_device(); + device_context->get_eigen_device(); ASSERT_NE(nullptr, gpu_device.stream()); - delete device; + delete device_context; } -} \ No newline at end of file +} + +TEST(Device, CUDADeviceContext) { + int count = paddle::platform::GetDeviceCount(); + for (int i = 0; i < count; i++) { + paddle::platform::CUDADeviceContext* device_context = + new paddle::platform::CUDADeviceContext(i); + Eigen::GpuDevice gpu_device = device_context->eigen_device(); + ASSERT_NE(nullptr, gpu_device.stream()); + cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); + ASSERT_NE(nullptr, cudnn_handle); + cublasHandle_t cublas_handle = device_context->cublas_handle(); + ASSERT_NE(nullptr, cublas_handle); + curandGenerator_t curand_handle = device_context->curand_generator(); + ASSERT_NE(nullptr, curand_handle); + delete device_context; + } +} From e0ea87c99d242ea19f23301bd97492e47cacf231 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 12 Jul 2017 15:38:14 +0800 Subject: [PATCH 09/21] fix pybind compile question --- paddle/pybind/pybind.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 55aebc59ec..f9f87acf15 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include @@ -43,4 +44,4 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference); return m.ptr(); -} \ No newline at end of file +} From 8ee50a35d408634c817d3da849a15217e57dcba1 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 07:50:08 +0000 Subject: [PATCH 10/21] fix gpu build error --- paddle/platform/CMakeLists.txt | 1 - paddle/platform/cuda_device_context.h | 15 +++++++-------- paddle/platform/device_context.h | 24 +++++++++++------------- paddle/platform/device_context_test.cc | 5 +++-- 4 files changed, 21 insertions(+), 24 deletions(-) diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index d40e49b546..e93592cc4c 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -5,5 +5,4 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -nv_test(cuda_device_test SRCS cuda_device_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h index c38dcd5a61..8a9d15e8a8 100644 --- a/paddle/platform/cuda_device_context.h +++ b/paddle/platform/cuda_device_context.h @@ -20,19 +20,13 @@ limitations under the License. */ #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" #define EIGEN_USE_GPU +#include "paddle/platform/device_context.h" #include "paddle/platform/place.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { -class CUDADeviceContext; - -template <> -Eigen::GpuDevice DeviceContext::get_eigen_device() { - return static_cast(this)->eigen_handle(); -} - class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { @@ -49,7 +43,7 @@ class GPUPlaceGuard { class CUDADeviceContext : public DeviceContext { public: - explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { GPUPlaceGuard guard(gpu_place_); paddle::platform::throw_on_error(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); @@ -156,5 +150,10 @@ class CUDADeviceContext : public DeviceContext { int random_seed_; curandGenerator_t rand_generator_{nullptr}; }; + +template <> +Eigen::GpuDevice DeviceContext::get_eigen_device() { + return dynamic_cast(this)->eigen_device(); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index d2a5169991..d2f7cf6216 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -20,30 +20,23 @@ limitations under the License. */ namespace paddle { namespace platform { -class CPUDeviceContext; - class DeviceContext { public: virtual ~DeviceContext() {} template - DeviceType get_eigen_device(); + inline DeviceType get_eigen_device(); virtual Place GetPlace() const = 0; }; -template <> -Eigen::DefaultDevice DeviceContext::get_eigen_device() { - return static_cast(this)->eigen_handle(); -} - class CPUDeviceContext : public DeviceContext { public: - Eigen::DefaultDevice eigen_handle() { - if (!eigen_handle_) { - eigen_handle_ = new Eigen::DefaultDevice(); + Eigen::DefaultDevice eigen_device() { + if (!eigen_device_) { + eigen_device_ = new Eigen::DefaultDevice(); } - return *eigen_handle_; + return *eigen_device_; } Place GetPlace() const override { @@ -52,7 +45,12 @@ class CPUDeviceContext : public DeviceContext { } private: - Eigen::DefaultDevice* eigen_handle_{nullptr}; + Eigen::DefaultDevice* eigen_device_{nullptr}; }; + +template <> +Eigen::DefaultDevice DeviceContext::get_eigen_device() { + return dynamic_cast(this)->eigen_device(); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 8390e97b15..abaaaececf 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -13,15 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" -#include "paddle/platform/cuda_device.h" +#include "paddle/platform/cuda_device_context.h" +using DEVICE_GPU = Eigen::GpuDevice; TEST(Device, Init) { int count = paddle::platform::GetDeviceCount(); for (int i = 0; i < count; i++) { paddle::platform::DeviceContext* device_context = new paddle::platform::CUDADeviceContext(i); Eigen::GpuDevice gpu_device = - device_context->get_eigen_device(); + device_context->template get_eigen_device(); ASSERT_NE(nullptr, gpu_device.stream()); delete device_context; } From 85806e75850aa6284afa4456daab7990186a0493 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 16:23:10 +0800 Subject: [PATCH 11/21] follow comments --- paddle/platform/CMakeLists.txt | 9 +- paddle/platform/cuda_device_context.h | 159 ------------------------- paddle/platform/device_context.cc | 24 ++++ paddle/platform/device_context.h | 151 +++++++++++++++++++++-- paddle/platform/device_context_test.cc | 2 +- 5 files changed, 177 insertions(+), 168 deletions(-) delete mode 100644 paddle/platform/cuda_device_context.h create mode 100644 paddle/platform/device_context.cc diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index e93592cc4c..358d14f455 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -5,4 +5,11 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -nv_test(device_context_test SRCS device_context_test.cc DEPS dynload_cuda dynamic_loader eigen3 place) +IF(WITH_GPU) + set(GPU_CTX_DEPS dynload_cuda dynamic_loader) +ELSE() + set(GPU_CTX_DEPS) +ENDIF() + +cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS}) +nv_test(device_context_test SRCS device_context_test.cc DEPS device_context glog gflags) diff --git a/paddle/platform/cuda_device_context.h b/paddle/platform/cuda_device_context.h deleted file mode 100644 index 8a9d15e8a8..0000000000 --- a/paddle/platform/cuda_device_context.h +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/enforce.h" -#include "paddle/platform/cuda.h" -#include "paddle/platform/dynload/cublas.h" -#include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" -#define EIGEN_USE_GPU -#include "paddle/platform/device_context.h" -#include "paddle/platform/place.h" -#include "unsupported/Eigen/CXX11/Tensor" - -namespace paddle { -namespace platform { - -class GPUPlaceGuard { - public: - explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { - if (previous_ != new_place) { - paddle::platform::SetDeviceId(new_place.device); - } - } - - ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } - - private: - GPUPlace previous_; -}; - -class CUDADeviceContext : public DeviceContext { - public: - explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { - GPUPlaceGuard guard(gpu_place_); - paddle::platform::throw_on_error(cudaStreamCreate(&stream_), - "cudaStreamCreate failed"); - eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); - eigen_device_ = new Eigen::GpuDevice(eigen_stream_); - } - - Place GetPlace() const override { - Place retv = GPUPlace(); - return retv; - } - - void Wait() { - paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), - "cudaStreamSynchronize failed"); - } - - cudaStream_t stream() { return stream_; } - - Eigen::GpuDevice eigen_device() { return *eigen_device_; } - - cublasHandle_t cublas_handle() { - if (!blas_handle_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == - CUBLAS_STATUS_SUCCESS, - "cublasCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( - blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, - "cublasSetStream failed"); - } - return blas_handle_; - } - - cudnnHandle_t cudnn_handle() { - if (!dnn_handle_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == - CUDNN_STATUS_SUCCESS, - "cudnnCreate failed"); - PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( - dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, - "cudnnSetStream failed"); - } - return dnn_handle_; - } - - curandGenerator_t curand_generator() { - if (!rand_generator_) { - GPUPlaceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( - &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == - CURAND_STATUS_SUCCESS, - "curandCreateGenerator failed"); - PADDLE_ENFORCE( - paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( - rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, - "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( - rand_generator_, stream_) == CURAND_STATUS_SUCCESS, - "curandSetStream failed"); - } - return rand_generator_; - } - - ~CUDADeviceContext() { - Wait(); - if (blas_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == - CUBLAS_STATUS_SUCCESS, - "cublasDestroy failed"); - } - - if (dnn_handle_) { - PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == - CUDNN_STATUS_SUCCESS, - "cudnnDestroy failed"); - } - - if (rand_generator_) { - PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( - rand_generator_) == CURAND_STATUS_SUCCESS, - "curandDestroyGenerator failed"); - } - - delete eigen_stream_; - delete eigen_device_; - - paddle::platform::throw_on_error(cudaStreamDestroy(stream_), - "cudaStreamDestroy failed"); - } - - private: - GPUPlace gpu_place_; - cudaStream_t stream_; - - Eigen::CudaStreamDevice* eigen_stream_; - Eigen::GpuDevice* eigen_device_; - - cublasHandle_t blas_handle_{nullptr}; - - cudnnHandle_t dnn_handle_{nullptr}; - - int random_seed_; - curandGenerator_t rand_generator_{nullptr}; -}; - -template <> -Eigen::GpuDevice DeviceContext::get_eigen_device() { - return dynamic_cast(this)->eigen_device(); -} -} // namespace platform -} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc new file mode 100644 index 0000000000..8d800ec499 --- /dev/null +++ b/paddle/platform/device_context.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace platform { +namespace dynload { +namespace dummy { +// Make DeviceContext A library. +int DUMMY_VAR_FOR_DEV_CTX = 0; + +} // namespace dummy +} // namespace dynload +} // namespace platform +} // namespace paddle \ No newline at end of file diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index d2f7cf6216..5b4b5e2999 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,9 +10,17 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include "paddle/framework/enforce.h" -#include "paddle/platform/place.h" -#include "unsupported/Eigen/CXX11/Tensor" +#ifndef PADDLE_ONLY_CPU +#include "paddle/platform/cuda.h" +#include "paddle/platform/dynload/cublas.h" +#include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/dynload/curand.h" +#define EIGEN_USE_GPU +#endif +#include +#include namespace paddle { namespace platform { @@ -23,11 +28,10 @@ namespace platform { class DeviceContext { public: virtual ~DeviceContext() {} + virtual Place GetPlace() const = 0; template inline DeviceType get_eigen_device(); - - virtual Place GetPlace() const = 0; }; class CPUDeviceContext : public DeviceContext { @@ -52,5 +56,138 @@ template <> Eigen::DefaultDevice DeviceContext::get_eigen_device() { return dynamic_cast(this)->eigen_device(); } + +#ifndef PADDLE_ONLY_CPU + +class GPUPlaceGuard { + public: + explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { + if (previous_ != new_place) { + paddle::platform::SetDeviceId(new_place.device); + } + } + + ~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); } + + private: + GPUPlace previous_; +}; + +class CUDADeviceContext : public DeviceContext { + public: + explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + GPUPlaceGuard guard(gpu_place_); + paddle::platform::throw_on_error(cudaStreamCreate(&stream_), + "cudaStreamCreate failed"); + eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); + eigen_device_ = new Eigen::GpuDevice(eigen_stream_); + } + + Place GetPlace() const override { + Place retv = GPUPlace(); + return retv; + } + + void Wait() { + paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), + "cudaStreamSynchronize failed"); + } + + cudaStream_t stream() { return stream_; } + + Eigen::GpuDevice eigen_device() { return *eigen_device_; } + + cublasHandle_t cublas_handle() { + if (!blas_handle_) { + GPUPlaceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == + CUBLAS_STATUS_SUCCESS, + "cublasCreate failed"); + PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( + blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, + "cublasSetStream failed"); + } + return blas_handle_; + } + + cudnnHandle_t cudnn_handle() { + if (!dnn_handle_) { + GPUPlaceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == + CUDNN_STATUS_SUCCESS, + "cudnnCreate failed"); + PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( + dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, + "cudnnSetStream failed"); + } + return dnn_handle_; + } + + curandGenerator_t curand_generator() { + if (!rand_generator_) { + GPUPlaceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( + &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == + CURAND_STATUS_SUCCESS, + "curandCreateGenerator failed"); + PADDLE_ENFORCE( + paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( + rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, + "curandSetPseudoRandomGeneratorSeed failed"); + PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( + rand_generator_, stream_) == CURAND_STATUS_SUCCESS, + "curandSetStream failed"); + } + return rand_generator_; + } + + ~CUDADeviceContext() { + Wait(); + if (blas_handle_) { + PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == + CUBLAS_STATUS_SUCCESS, + "cublasDestroy failed"); + } + + if (dnn_handle_) { + PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == + CUDNN_STATUS_SUCCESS, + "cudnnDestroy failed"); + } + + if (rand_generator_) { + PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( + rand_generator_) == CURAND_STATUS_SUCCESS, + "curandDestroyGenerator failed"); + } + + delete eigen_stream_; + delete eigen_device_; + + paddle::platform::throw_on_error(cudaStreamDestroy(stream_), + "cudaStreamDestroy failed"); + } + + private: + GPUPlace gpu_place_; + cudaStream_t stream_; + + Eigen::CudaStreamDevice* eigen_stream_; + Eigen::GpuDevice* eigen_device_; + + cublasHandle_t blas_handle_{nullptr}; + + cudnnHandle_t dnn_handle_{nullptr}; + + int random_seed_; + curandGenerator_t rand_generator_{nullptr}; +}; + +template <> +Eigen::GpuDevice DeviceContext::get_eigen_device() { + return dynamic_cast(this)->eigen_device(); +} +#endif + } // namespace platform } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index abaaaececf..913e3c0aa9 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/platform/device_context.h" #include "gtest/gtest.h" -#include "paddle/platform/cuda_device_context.h" using DEVICE_GPU = Eigen::GpuDevice; TEST(Device, Init) { From 6986a89331673df8c449a2894747c027cc52cc34 Mon Sep 17 00:00:00 2001 From: gangliao Date: Wed, 12 Jul 2017 17:02:53 +0800 Subject: [PATCH 12/21] FIX: add -lrt for link --- cmake/generic.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 83e3d155d0..a30cdeff62 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -93,7 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) if(NOT APPLE) find_package(Threads REQUIRED) link_libraries(${CMAKE_THREAD_LIBS_INIT}) - set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl") + set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt") endif(NOT APPLE) function(merge_static_libs TARGET_NAME) From a07deac9efb1dc2ff7cea2a9534847512533a8b1 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 09:09:12 +0000 Subject: [PATCH 13/21] follow comments --- paddle/platform/device_context.cc | 20 +++++++++++++------- paddle/platform/device_context.h | 11 +---------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 8d800ec499..25ff352e8c 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -13,12 +13,18 @@ limitations under the License. */ namespace paddle { namespace platform { -namespace dynload { -namespace dummy { -// Make DeviceContext A library. -int DUMMY_VAR_FOR_DEV_CTX = 0; -} // namespace dummy -} // namespace dynload +template <> +Eigen::DefaultDevice DeviceContext::get_eigen_device() { + return reinterpret_cast(this)->eigen_device(); +} + +#ifndef PADDLE_ONLY_CPU +template <> +Eigen::GpuDevice DeviceContext::get_eigen_device() { + return reinterpret_cast(this)->eigen_device(); +} +#endif + } // namespace platform -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 5b4b5e2999..d6cf114216 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -31,7 +31,7 @@ class DeviceContext { virtual Place GetPlace() const = 0; template - inline DeviceType get_eigen_device(); + DeviceType get_eigen_device(); }; class CPUDeviceContext : public DeviceContext { @@ -52,11 +52,6 @@ class CPUDeviceContext : public DeviceContext { Eigen::DefaultDevice* eigen_device_{nullptr}; }; -template <> -Eigen::DefaultDevice DeviceContext::get_eigen_device() { - return dynamic_cast(this)->eigen_device(); -} - #ifndef PADDLE_ONLY_CPU class GPUPlaceGuard { @@ -183,10 +178,6 @@ class CUDADeviceContext : public DeviceContext { curandGenerator_t rand_generator_{nullptr}; }; -template <> -Eigen::GpuDevice DeviceContext::get_eigen_device() { - return dynamic_cast(this)->eigen_device(); -} #endif } // namespace platform From be2c1a3b99947ace0717ac79ebbf1b25ecb1055d Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 09:41:33 +0000 Subject: [PATCH 14/21] follow comments --- paddle/platform/device_context.cc | 4 ++-- paddle/platform/device_context.h | 26 ++++++++++++-------------- paddle/platform/device_context_test.cc | 8 ++++---- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 25ff352e8c..960ef0a595 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -15,13 +15,13 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice DeviceContext::get_eigen_device() { +Eigen::DefaultDevice* DeviceContext::get_eigen_device() { return reinterpret_cast(this)->eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice DeviceContext::get_eigen_device() { +Eigen::GpuDevice* DeviceContext::get_eigen_device() { return reinterpret_cast(this)->eigen_device(); } #endif diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index d6cf114216..94f54d705d 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -31,16 +31,16 @@ class DeviceContext { virtual Place GetPlace() const = 0; template - DeviceType get_eigen_device(); + DeviceType* get_eigen_device(); }; class CPUDeviceContext : public DeviceContext { public: - Eigen::DefaultDevice eigen_device() { + Eigen::DefaultDevice* eigen_device() { if (!eigen_device_) { - eigen_device_ = new Eigen::DefaultDevice(); + eigen_device_.reset(new Eigen::DefaultDevice()); } - return *eigen_device_; + return eigen_device_.get(); } Place GetPlace() const override { @@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext { } private: - Eigen::DefaultDevice* eigen_device_{nullptr}; + std::unique_ptr eigen_device_; }; #ifndef PADDLE_ONLY_CPU @@ -74,8 +74,8 @@ class CUDADeviceContext : public DeviceContext { GPUPlaceGuard guard(gpu_place_); paddle::platform::throw_on_error(cudaStreamCreate(&stream_), "cudaStreamCreate failed"); - eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); - eigen_device_ = new Eigen::GpuDevice(eigen_stream_); + eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); + eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); } Place GetPlace() const override { @@ -90,7 +90,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream() { return stream_; } - Eigen::GpuDevice eigen_device() { return *eigen_device_; } + Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); } cublasHandle_t cublas_handle() { if (!blas_handle_) { @@ -155,10 +155,8 @@ class CUDADeviceContext : public DeviceContext { rand_generator_) == CURAND_STATUS_SUCCESS, "curandDestroyGenerator failed"); } - - delete eigen_stream_; - delete eigen_device_; - + eigen_stream_.reset(); + eigen_device_.reset(); paddle::platform::throw_on_error(cudaStreamDestroy(stream_), "cudaStreamDestroy failed"); } @@ -167,8 +165,8 @@ class CUDADeviceContext : public DeviceContext { GPUPlace gpu_place_; cudaStream_t stream_; - Eigen::CudaStreamDevice* eigen_stream_; - Eigen::GpuDevice* eigen_device_; + std::unique_ptr eigen_stream_; + std::unique_ptr eigen_device_; cublasHandle_t blas_handle_{nullptr}; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 913e3c0aa9..af2ce17fc2 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -21,9 +21,9 @@ TEST(Device, Init) { for (int i = 0; i < count; i++) { paddle::platform::DeviceContext* device_context = new paddle::platform::CUDADeviceContext(i); - Eigen::GpuDevice gpu_device = + Eigen::GpuDevice* gpu_device = device_context->template get_eigen_device(); - ASSERT_NE(nullptr, gpu_device.stream()); + ASSERT_NE(nullptr, gpu_device); delete device_context; } } @@ -33,8 +33,8 @@ TEST(Device, CUDADeviceContext) { for (int i = 0; i < count; i++) { paddle::platform::CUDADeviceContext* device_context = new paddle::platform::CUDADeviceContext(i); - Eigen::GpuDevice gpu_device = device_context->eigen_device(); - ASSERT_NE(nullptr, gpu_device.stream()); + Eigen::GpuDevice* gpu_device = device_context->eigen_device(); + ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); From 70d937c595fb7f945bfae21d7d2a81f2a7ccc45a Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 12 Jul 2017 10:17:26 +0000 Subject: [PATCH 15/21] add memory header file --- paddle/platform/device_context.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 94f54d705d..7de07d06be 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -20,6 +20,7 @@ limitations under the License. */ #define EIGEN_USE_GPU #endif #include +#include #include namespace paddle { From be441f7d162bd9638e07a6558cf2de9dd3c8b412 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 12 Jul 2017 20:36:40 +0800 Subject: [PATCH 16/21] test OpKernel (#2820) Add unit test for OpKernel --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/op_registry.h | 14 ++++---- paddle/framework/op_registry_test.cc | 4 +-- paddle/framework/operator_test.cc | 52 +++++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 11 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index aac49fdb7a..b8642ca22a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc protobuf) -cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) +cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 02c99d50bb..248c7a1a3b 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -147,13 +147,13 @@ class OpRegisterHelper { } }; -#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ - class __op_class##Register { \ - private: \ - const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ - }; \ - const OpRegisterHelper<__op_class, __op_maker_class> \ - __op_class##Register::reg(#__op_type); +#define REGISTER_OP(type, op_class, op_maker_class) \ + class op_class##Register { \ + private: \ + const static OpRegisterHelper reg; \ + }; \ + const OpRegisterHelper op_class##Register::reg( \ + #type) } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index f5d45a80bb..f5162fb870 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -26,7 +26,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } }; -REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) +REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker); class MyTestOp : public OperatorBase { public: @@ -53,7 +53,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } }; -REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) +REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker); } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 86f45f108a..be8c4be2d4 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -45,7 +45,7 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } }; -REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator) +REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker); TEST(OperatorBase, all) { OpDesc op_desc; @@ -69,5 +69,55 @@ TEST(OperatorBase, all) { delete op; } +class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddOutput("output", "output of test op"); + AddAttr("scale", "scale of cosine op") + .SetDefault(1.0) + .LargerThan(0.0); + AddType("test_operator"); + AddComment("This is test op"); + } +}; + +class OpWithKernelTest : public OperatorWithKernel { + public: + void InferShape(const std::shared_ptr& scope) const override {} +}; + +class CPUKernelTest : public OpKernel { + public: + void Compute(const KernelContext& context) const { + float scale = context.op_.GetAttr("scale"); + ASSERT_NEAR(scale, 3.14, 1e-5); + std::cout << "this is cpu kernel" << std::endl; + std::cout << context.op_.DebugString() << std::endl; + } +}; + +REGISTER_OP(op_with_kernel, OpWithKernelTest, OpKernelTestProtoAndCheckerMaker); +REGISTER_OP_KERNEL(op_with_kernel, platform::CPUPlace, CPUKernelTest); + +TEST(OpKernel, all) { + OpDesc op_desc; + op_desc.set_type("op_with_kernel"); + *op_desc.mutable_inputs()->Add() = "IN1"; + *op_desc.mutable_outputs()->Add() = "OUT1"; + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(3.14); + + platform::CPUDeviceContext cpu_device_context; + auto scope = std::make_shared(); + + OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); + op->Run(scope, cpu_device_context); + + delete op; +} } // namespace framework } // namespace paddle \ No newline at end of file From e4be077ffa44465fe19f47c892164452fdecfdfb Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 11 Jul 2017 17:47:28 -0400 Subject: [PATCH 17/21] Add go testing into cmake and fix libpaddle_go_optimizer.a link path --- CMakeLists.txt | 4 ++++ cmake/generic.cmake | 21 ++++++++++++--------- go/CMakeLists.txt | 3 +++ go/master/CMakeLists.txt | 3 +++ go/pserver/CMakeLists.txt | 3 +++ go/pserver/client/c/CMakeLists.txt | 8 ++++++++ go/pserver/optimizer.go | 3 +-- go/utils/networkhelper/CMakeLists.txt | 3 +++ paddle/CMakeLists.txt | 1 - 9 files changed, 37 insertions(+), 12 deletions(-) create mode 100644 go/master/CMakeLists.txt create mode 100644 go/pserver/CMakeLists.txt create mode 100644 go/utils/networkhelper/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 2c713db3e3..6bc6a8077c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -140,6 +140,10 @@ endif(USE_NNPACK) add_subdirectory(proto) +# "add_subdirectory(go)" should be placed after the following loine, +# because it depends on paddle/optimizer. +add_subdirectory(paddle/optimizer) + # "add_subdirectory(paddle)" and "add_subdirectory(python)" should be # placed after this block, because they depends on it. if(WITH_GOLANG) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 83e3d155d0..f88b9dff2b 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -301,7 +301,7 @@ function(go_library TARGET_NAME) file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) - # FIXME: link path + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD COMMAND rm "${${TARGET_NAME}_LIB_PATH}" # Golang build source code @@ -309,7 +309,7 @@ function(go_library TARGET_NAME) -o "${${TARGET_NAME}_LIB_PATH}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}" # must run under GOPATH - WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") + WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") add_dependencies(${TARGET_NAME} go_vendor) endfunction(go_library) @@ -322,8 +322,8 @@ function(go_binary TARGET_NAME) # FIXME: link path add_custom_command(OUTPUT ${TARGET_NAME}_timestamp - COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} - GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build + COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} + GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}" WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") @@ -335,15 +335,18 @@ endfunction(go_binary) function(go_test TARGET_NAME) set(options OPTIONAL) set(oneValueArgs "") - set(multiValueArgs SRCS DEPS) + set(multiValueArgs DEPS) cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - add_custom_command(OUTPUT ${TARGET_NAME}_timestamp + string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) + add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS}) + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" - ${go_test_SRCS} + ".${CMAKE_CURRENT_SOURCE_REL_DIR}" + WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") + add_test(NAME ${TARGET_NAME} + COMMAND ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) - add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) endfunction(go_test) function(proto_library TARGET_NAME) diff --git a/go/CMakeLists.txt b/go/CMakeLists.txt index f00c70a058..18fee46d19 100644 --- a/go/CMakeLists.txt +++ b/go/CMakeLists.txt @@ -17,3 +17,6 @@ add_subdirectory(pserver/client/c) add_subdirectory(cmd/pserver) add_subdirectory(cmd/master) add_subdirectory(master/c) +add_subdirectory(master) +add_subdirectory(pserver) +add_subdirectory(utils/networkhelper) diff --git a/go/master/CMakeLists.txt b/go/master/CMakeLists.txt new file mode 100644 index 0000000000..30531e6469 --- /dev/null +++ b/go/master/CMakeLists.txt @@ -0,0 +1,3 @@ +if(WITH_TESTING) + go_test(master_test) +endif() diff --git a/go/pserver/CMakeLists.txt b/go/pserver/CMakeLists.txt new file mode 100644 index 0000000000..6267040a6e --- /dev/null +++ b/go/pserver/CMakeLists.txt @@ -0,0 +1,3 @@ +if(WITH_TESTING) + go_test(pserver_test DEPS paddle_go_optimizer) +endif() diff --git a/go/pserver/client/c/CMakeLists.txt b/go/pserver/client/c/CMakeLists.txt index 93a0a27f85..c6333eab55 100644 --- a/go/pserver/client/c/CMakeLists.txt +++ b/go/pserver/client/c/CMakeLists.txt @@ -1,5 +1,13 @@ cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) target_link_libraries(paddle_go_optimizer stdc++ m) + +# Copy library to the required place. +# See: go/pserver/optimizer.go: +# // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm +add_custom_command(TARGET paddle_go_optimizer POST_BUILD + COMMAND cp "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a" "${CMAKE_CURRENT_SOURCE_DIR}" + ) + go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) if(WITH_TESTING) # FIXME: this test requires pserver which is not managed by the test diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 2d7882d1a7..0ebf4a26fa 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -1,8 +1,7 @@ package pserver // #cgo CFLAGS: -I ../../ -// //FIXME: ldflags contain "build" path -// #cgo LDFLAGS: ${SRCDIR}/../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ -lm +// #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm // #include "paddle/optimizer/optimizer.h" // #include // #include diff --git a/go/utils/networkhelper/CMakeLists.txt b/go/utils/networkhelper/CMakeLists.txt new file mode 100644 index 0000000000..db6cf211d8 --- /dev/null +++ b/go/utils/networkhelper/CMakeLists.txt @@ -0,0 +1,3 @@ +if(WITH_TESTING) + go_test(network_helper_test) +endif() diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 58a35564f8..0b5e9a2599 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -8,7 +8,6 @@ add_subdirectory(gserver) add_subdirectory(pserver) add_subdirectory(trainer) add_subdirectory(scripts) -add_subdirectory(optimizer) add_subdirectory(string) if(Boost_FOUND) From 59287cd1cad1e2d6006eff68d8f025af3dd0c310 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 12 Jul 2017 22:30:44 +0000 Subject: [PATCH 18/21] add .gitignore --- go/pserver/client/c/.gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 go/pserver/client/c/.gitignore diff --git a/go/pserver/client/c/.gitignore b/go/pserver/client/c/.gitignore new file mode 100644 index 0000000000..4bf05c8538 --- /dev/null +++ b/go/pserver/client/c/.gitignore @@ -0,0 +1 @@ +libpaddle_go_optimizer.a From 2231b92a89ea560934be92987c27068be398c6fd Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 12 Jul 2017 23:20:06 +0000 Subject: [PATCH 19/21] go_binary: remove hardcoded library link path --- cmake/generic.cmake | 5 +---- go/cmd/master/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index b13400d125..71ee266611 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -320,14 +320,11 @@ function(go_binary TARGET_NAME) cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) - # FIXME: link path add_custom_command(OUTPUT ${TARGET_NAME}_timestamp - COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} - GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build + COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}" WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") - # TODO: don't know what ${TARGET_NAME}_link does add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS}) install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin) endfunction(go_binary) diff --git a/go/cmd/master/CMakeLists.txt b/go/cmd/master/CMakeLists.txt index 1058ffa86b..9e149967e7 100644 --- a/go/cmd/master/CMakeLists.txt +++ b/go/cmd/master/CMakeLists.txt @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -go_binary(master SRC master.go DEPS paddle_go_optimizer) +go_binary(master SRC master.go) From b04986da9f57cfba0657194c7e35b7e9229a6676 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 12 Jul 2017 23:48:06 +0000 Subject: [PATCH 20/21] add pserver client test --- go/CMakeLists.txt | 1 + go/pserver/client/CMakeLists.txt | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 go/pserver/client/CMakeLists.txt diff --git a/go/CMakeLists.txt b/go/CMakeLists.txt index 18fee46d19..29ce909c64 100644 --- a/go/CMakeLists.txt +++ b/go/CMakeLists.txt @@ -19,4 +19,5 @@ add_subdirectory(cmd/master) add_subdirectory(master/c) add_subdirectory(master) add_subdirectory(pserver) +add_subdirectory(pserver/client) add_subdirectory(utils/networkhelper) diff --git a/go/pserver/client/CMakeLists.txt b/go/pserver/client/CMakeLists.txt new file mode 100644 index 0000000000..0052bb460b --- /dev/null +++ b/go/pserver/client/CMakeLists.txt @@ -0,0 +1,3 @@ +if(WITH_TESTING) + go_test(pserver_client_test DEPS paddle_go_optimizer) +endif() From 19bfb8a1f2153f0d5368808bc09580c7e4c7b07c Mon Sep 17 00:00:00 2001 From: Yancey Date: Thu, 13 Jul 2017 09:52:26 +0800 Subject: [PATCH 21/21] PServer recovery from checkpoint (#2741) * Server recovery from checkpoint --- .gitignore | 3 ++ go/cmd/pserver/pserver.go | 39 +++++++------- go/glide.lock | 14 ++--- go/glide.yaml | 1 + go/pserver/etcd_client.go | 22 ++++++-- go/pserver/service.go | 104 ++++++++++++++++++++++++++------------ 6 files changed, 121 insertions(+), 62 deletions(-) diff --git a/.gitignore b/.gitignore index 5c2fb134ae..c84b2fc8c7 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ cmake-build-* # generated while compiling python/paddle/v2/framework/core.so +CMakeFiles +cmake_install.cmake + diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 0ecb1242c3..48351ab6d0 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -8,6 +8,7 @@ import ( "time" "github.com/namsral/flag" + "github.com/topicai/candy" "github.com/PaddlePaddle/Paddle/go/pserver" log "github.com/sirupsen/logrus" @@ -18,53 +19,47 @@ func main() { index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") - etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") + etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") - checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds") + checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() level, err := log.ParseLevel(*logLevel) - if err != nil { - panic(err) - } + candy.Must(err) + log.SetLevel(level) var idx int - var cp pserver.Checkpoint + + var cp *pserver.Checkpoint var e *pserver.EtcdClient if *index >= 0 { idx = *index } else { - timeout := time.Second * time.Duration((*etcdTimeout)) - e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout) idx, err = e.Register() + candy.Must(err) + + cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e) if err != nil { - panic(err) + log.Errorf("Fetch checkpoint failed, %s", err) } } s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) - if err != nil { - panic(err) - } + candy.Must(err) + err = rpc.Register(s) - if err != nil { - panic(err) - } + candy.Must(err) rpc.HandleHTTP() l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) - if err != nil { - panic(err) - } + candy.Must(err) log.Infof("start pserver at port %d", *port) err = http.Serve(l, nil) - - if err != nil { - panic(err) - } + candy.Must(err) } diff --git a/go/glide.lock b/go/glide.lock index 190a222338..f71ae643d6 100644 --- a/go/glide.lock +++ b/go/glide.lock @@ -1,8 +1,8 @@ -hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7 -updated: 2017-06-27T14:05:48.925262819+08:00 +hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855 +updated: 2017-07-11T10:04:40.786745417+08:00 imports: - name: github.com/coreos/etcd - version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7 + version: cb2a496c4ddd1c87a9f280e116649b599999ec79 subpackages: - auth/authpb - clientv3 @@ -22,7 +22,9 @@ imports: - name: github.com/PaddlePaddle/recordio version: edfb82af0739c84f241c87390ec5649c7b28c129 - name: github.com/sirupsen/logrus - version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b + version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1 +- name: github.com/topicai/candy + version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc - name: golang.org/x/net version: c8c74377599bd978aee1cf3b9b63a8634051cec2 subpackages: @@ -34,11 +36,11 @@ imports: - lex/httplex - trace - name: golang.org/x/sys - version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733 + version: abf9c25f54453410d0c6668e519582a9e1115027 subpackages: - unix - name: golang.org/x/text - version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77 + version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa subpackages: - secure/bidirule - transform diff --git a/go/glide.yaml b/go/glide.yaml index 05c5d15ca2..ab472c7cda 100644 --- a/go/glide.yaml +++ b/go/glide.yaml @@ -10,3 +10,4 @@ import: version: ^1.7.4-pre - package: github.com/sirupsen/logrus version: ^1.0.0 +- package: github.com/topicai/candy diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 1f77787150..4a694b97f4 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -16,7 +16,7 @@ import ( const ( // PsDesired is etcd path for store desired pserver count PsDesired = "/ps_desired" - // PsAddr is the base dir for pserver to store their addr + // PsPath is the base dir for pserver to store their addr PsPath = "/ps/" // PsCheckpoint is the etcd path for store checkpoints information PsCheckpoint = "/checkpoints/" @@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { return idx, nil } +// GetKey gets the value by the specified key +func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + resp, err := e.etcdClient.Get(ctx, key) + cancel() + if err != nil { + return []byte{}, err + } + kvs := resp.Kvs + if len(kvs) == 0 { + return []byte{}, nil + } + v := kvs[0].Value + return v, nil +} + // PutKey put into etcd with value by key specified -func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) +func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) _, err := e.etcdClient.Put(ctx, key, string(value)) cancel() if err != nil { diff --git a/go/pserver/service.go b/go/pserver/service.go index 6b52d0d896..65db6970a7 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "io/ioutil" "os" "path/filepath" "strconv" @@ -21,14 +22,14 @@ import ( // ElementType is the type of elements of a Parameter. type ElementType int +// RPC error message. const ( - // AlreadyInitialized is true if pserver is initialized - AlreadyInitialized = "pserver already initialized" - // Uninitialized is true if pserver not fully initialized - Uninitialized = "pserver not fully initialized" + AlreadyInitialized = "pserver already initialized" + Uninitialized = "pserver not fully initialized" + CheckpointMD5Failed = "checkpoint file MD5 validation failed" ) -// Supported element types +// Supported element types. const ( Int32 ElementType = iota UInt32 @@ -51,21 +52,15 @@ type ParameterWithConfig struct { Config []byte // parameter configuration in Proto Buffer format } -// ParameterCheckpoint is Parameter and State checkpoint -type ParameterCheckpoint struct { - ParamConfig ParameterWithConfig - State []byte -} - -// checkpoint signature +// checkpointMeta saves checkpoint metadata type checkpointMeta struct { UUID string `json:"uuid"` - Md5sum string `json:"md5sum"` - Timestamp string `json:"timestamp"` + MD5 string `json:"md5"` + Timestamp int64 `json:"timestamp"` } // Checkpoint is the pserver shard persist in file -type Checkpoint []ParameterCheckpoint +type Checkpoint []parameterCheckpoint // Gradient is the gradient of the parameter. type Gradient Parameter @@ -81,12 +76,53 @@ type Service struct { optMap map[string]*optimizer } +// parameterCheckpoint saves parameter checkpoint +type parameterCheckpoint struct { + ParameterWithConfig + State []byte +} + +// NewCheckpointFromFile loads parameters and state from checkpoint file +func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) { + v, err := e.GetKey(PsPath+string(idx), 3*time.Second) + if err != nil { + return nil, err + } + + var cpMeta checkpointMeta + if err = json.Unmarshal(v, &cpMeta); err != nil { + return nil, err + } + + fn := filepath.Join(cpPath, cpMeta.UUID) + if _, err = os.Stat(fn); os.IsNotExist(err) { + return nil, err + } + content, err := ioutil.ReadFile(fn) + if err != nil { + return nil, err + } + + h := md5.New() + md5 := hex.EncodeToString(h.Sum(content)) + if md5 != cpMeta.MD5 { + return nil, errors.New(CheckpointMD5Failed) + } + + dec := gob.NewDecoder(bytes.NewReader(content)) + cp := &Checkpoint{} + if err = dec.Decode(cp); err != nil { + return nil, err + } + return cp, nil +} + // NewService creates a new service, will bypass etcd registration if no -// endpoints specified. -func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { +// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. +func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) { s := &Service{ idx: idx, - checkpointInterval: time.Second * time.Duration(seconds), + checkpointInterval: interval, checkpointPath: path, client: client, } @@ -94,10 +130,12 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp s.initialized = make(chan struct{}) if cp != nil { - for _, item := range cp { - p := item.ParamConfig - st := item.State - s.optMap[p.Param.Name] = newOptimizer(p, st) + for _, item := range *cp { + p := ParameterWithConfig{ + Param: item.Param, + Config: item.Config, + } + s.optMap[p.Param.Name] = newOptimizer(p, item.State) } } return s, nil @@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error { s.mu.Lock() defer s.mu.Unlock() - cp := make([]ParameterCheckpoint, 0, len(s.optMap)) + cp := make([]parameterCheckpoint, len(s.optMap)) index := 0 for name, opt := range s.optMap { - var pc ParameterCheckpoint - pc.ParamConfig.Param.Name = name - pc.ParamConfig.Param.ElementType = opt.elementType - pc.ParamConfig.Param.Content = opt.GetWeights() + var pc parameterCheckpoint + pc.Param.Name = name + pc.Param.ElementType = opt.elementType + pc.Param.Content = opt.GetWeights() pc.State = opt.GetStates() cp[index] = pc index++ @@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error { cpMeta := checkpointMeta{} cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) - cpMeta.Timestamp = time.Now().String() + cpMeta.Timestamp = time.Now().UnixNano() h := md5.New() - cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) + cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes())) cpMetajson, _ := json.Marshal(cpMeta) - err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) + err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second) if err != nil { return err } @@ -219,7 +257,11 @@ func (s *Service) doCheckpoint() error { log.Info("checkpoint does not exists.") } else { err = os.Remove(cpMeta.UUID) - log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) + if err != nil { + log.Infof("Removing checkpoint %s failed", cpMeta.UUID) + } else { + log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) + } } f, err := os.Create(cpMeta.UUID) defer f.Close()