From 3ba7a738f3f3e77240d026db57692d66bc9481ed Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 3 Jul 2017 20:37:42 +0800 Subject: [PATCH 1/7] add dynamic_load --- paddle/platform/cublas.h | 87 +++++++++++++++++ paddle/platform/cudnn.h | 114 ++++++++++++++++++++++ paddle/platform/curand.h | 42 ++++++++ paddle/platform/dynamic_loader.cc | 157 ++++++++++++++++++++++++++++++ paddle/platform/dynamic_loader.h | 63 ++++++++++++ 5 files changed, 463 insertions(+) create mode 100644 paddle/platform/cublas.h create mode 100644 paddle/platform/cudnn.h create mode 100644 paddle/platform/curand.h create mode 100644 paddle/platform/dynamic_loader.cc create mode 100644 paddle/platform/dynamic_loader.h diff --git a/paddle/platform/cublas.h b/paddle/platform/cublas.h new file mode 100644 index 0000000000..70c9713325 --- /dev/null +++ b/paddle/platform/cublas.h @@ -0,0 +1,87 @@ +#include +#include "paddle/platform/dynamic_loader.h" + +namespace paddle { +namespace dyload { +namespace dynload { + +std::once_flag cublas_dso_flag; +void *cublas_dso_handle = nullptr; + +/** + * The following macro definition can generate structs + * (for each function) to dynamic load cublas routine + * via operator overloading. + * + * note: default dynamic linked libs + */ +#ifdef PADDLE_USE_DSO +#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + cublasStatus_t operator()(Args... args) { \ + typedef cublasStatus_t (*cublasFunc)(Args...); \ + std::call_once(cublas_dso_flag, GetCublasDsoHandle, &cublas_dso_handle); \ + void *p_##__name = dlsym(cublas_dso_handle, #__name); \ + return reinterpret_cast<cublasFunc>(p_##__name)(args...); \ + } \ + } __name; // struct DynLoad__##__name +#else +#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + cublasStatus_t operator()(Args... 
args) { \ + return __name(args...); \ + } \ + } __name; // struct DynLoad__##__name +#endif + +#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name) + +// include all needed cublas functions in HPPL +// clang-format off +#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasSgemv) \ + __macro(cublasDgemv) \ + __macro(cublasSgemm) \ + __macro(cublasDgemm) \ + __macro(cublasSgeam) \ + __macro(cublasDgeam) \ + +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode) +DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched) +CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) + +#undef DYNAMIC_LOAD_CUBLAS_WRAP +#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP +#undef CUBLAS_BLAS_ROUTINE_EACH + +} /* namespace dynload */ + +// clang-format on +#ifndef PADDLE_TYPE_DOUBLE +#define CUBLAS_GEAM dynload::cublasSgeam +#define CUBLAS_GEMV dynload::cublasSgemv +#define CUBLAS_GEMM dynload::cublasSgemm +#define CUBLAS_GETRF dynload::cublasSgetrfBatched +#define CUBLAS_GETRI dynload::cublasSgetriBatched +#else +#define CUBLAS_GEAM dynload::cublasDgeam +#define CUBLAS_GEMV dynload::cublasDgemv +#define CUBLAS_GEMM dynload::cublasDgemm +#define CUBLAS_GETRF dynload::cublasDgetrfBatched +#define CUBLAS_GETRI dynload::cublasDgetriBatched +#endif +} // namespace dyload +} // namespace paddle diff --git a/paddle/platform/cudnn.h b/paddle/platform/cudnn.h new file mode 100644 index 0000000000..ab878cd555 --- /dev/null +++ b/paddle/platform/cudnn.h @@ -0,0 +1,114 @@ +#include +#include 
"paddle/platform/dynamic_loader.h" + +namespace paddle { +namespace dyload { + +std::once_flag cudnn_dso_flag; +void* cudnn_dso_handle = nullptr; + +#ifdef PADDLE_USE_DSO + +#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + using cudnn_func = decltype(__name(args...)) (*)(Args...); \ + std::call_once(cudnn_dso_flag, GetCudnnDsoHandle, &cudnn_dso_handle); \ + void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ + return reinterpret_cast<cudnn_func>(p_##__name)(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ + +#else + +#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + return __name(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ + +#endif + +/** + * include all needed cudnn functions in HPPL + * different cudnn version has different interfaces + **/ +// clang-format off +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor) \ + __macro(cudnnSetTensor4dDescriptorEx) \ + __macro(cudnnGetConvolutionNdForwardOutputDim) \ + __macro(cudnnGetConvolutionForwardAlgorithm) \ + __macro(cudnnCreateTensorDescriptor) \ + __macro(cudnnDestroyTensorDescriptor) \ + __macro(cudnnCreateFilterDescriptor) \ + __macro(cudnnSetFilter4dDescriptor) \ + __macro(cudnnSetPooling2dDescriptor) \ + __macro(cudnnDestroyFilterDescriptor) \ + __macro(cudnnCreateConvolutionDescriptor) \ + __macro(cudnnCreatePoolingDescriptor) \ + __macro(cudnnDestroyPoolingDescriptor) \ + __macro(cudnnSetConvolution2dDescriptor) \ + __macro(cudnnDestroyConvolutionDescriptor) \ + __macro(cudnnCreate) \ + __macro(cudnnDestroy) \ + __macro(cudnnSetStream) \ + __macro(cudnnActivationForward) \ + __macro(cudnnConvolutionForward) \ + __macro(cudnnConvolutionBackwardBias) \ + __macro(cudnnGetConvolutionForwardWorkspaceSize) \ + __macro(cudnnTransformTensor) \ + 
__macro(cudnnPoolingForward) \ + __macro(cudnnPoolingBackward) \ + __macro(cudnnSoftmaxBackward) \ + __macro(cudnnSoftmaxForward) \ + __macro(cudnnGetVersion) \ + __macro(cudnnGetErrorString) +CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP) + +#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \ + __macro(cudnnAddTensor) \ + __macro(cudnnConvolutionBackwardData) \ + __macro(cudnnConvolutionBackwardFilter) +CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP) + +// APIs available after R3: +#if CUDNN_VERSION >= 3000 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \ + __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \ + __macro(cudnnGetConvolutionBackwardDataAlgorithm) \ + __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \ + __macro(cudnnGetConvolutionBackwardDataWorkspaceSize) +CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3 +#endif + + +// APIs available after R4: +#if CUDNN_VERSION >= 4007 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \ + __macro(cudnnBatchNormalizationForwardTraining) \ + __macro(cudnnBatchNormalizationForwardInference) \ + __macro(cudnnBatchNormalizationBackward) +CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4 +#endif + +// APIs in R5 +#if CUDNN_VERSION >= 5000 +#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \ + __macro(cudnnCreateActivationDescriptor) \ + __macro(cudnnSetActivationDescriptor) \ + __macro(cudnnGetActivationDescriptor) \ + __macro(cudnnDestroyActivationDescriptor) +CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_R5 +#endif + +#undef CUDNN_DNN_ROUTINE_EACH +// clang-format on +} // namespace dyload +} // namespace paddle diff --git a/paddle/platform/curand.h b/paddle/platform/curand.h new file mode 100644 index 0000000000..692c024e6e --- /dev/null +++ b/paddle/platform/curand.h @@ -0,0 +1,42 @@ +#include +#include "paddle/platform/dynamic_loader.h" + +namespace paddle { +namespace dyload { +#ifdef 
PADDLE_USE_DSO +#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + curandStatus_t operator()(Args... args) { \ + typedef curandStatus_t (*curandFunc)(Args...); \ + std::call_once(curand_dso_flag, GetCurandDsoHandle, &curand_dso_handle); \ + void *p_##__name = dlsym(curand_dso_handle, #__name); \ + return reinterpret_cast<curandFunc>(p_##__name)(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ +#else +#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + curandStatus_t operator()(Args... args) { \ + return __name(args...); \ + } \ + } __name; /* struct DynLoad__##__name */ +#endif + +/* include all needed curand functions in HPPL */ +// clang-format off +#define CURAND_RAND_ROUTINE_EACH(__macro) \ + __macro(curandCreateGenerator) \ + __macro(curandSetStream) \ + __macro(curandSetPseudoRandomGeneratorSeed)\ + __macro(curandGenerateUniform) \ + __macro(curandGenerateUniformDouble) +// clang-format on + +CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) + +#undef CURAND_RAND_ROUTINE_EACH +#undef DYNAMIC_LOAD_CURAND_WRAP +} +} // namespace paddle diff --git a/paddle/platform/dynamic_loader.cc b/paddle/platform/dynamic_loader.cc new file mode 100644 index 0000000000..9036eaf642 --- /dev/null +++ b/paddle/platform/dynamic_loader.cc @@ -0,0 +1,157 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "DynamicLoader.h" +#include "Logging.h" + +DEFINE_string(cudnn_dir, "", + "Specify path for loading libcudnn.so. For instance, " + "/usr/local/cudnn/lib. If empty [default], dlopen " + "will search cudnn from LD_LIBRARY_PATH"); + +DEFINE_string(cuda_dir, "", + "Specify path for loading cuda library, such as libcublas, " + "libcurand. For instance, /usr/local/cuda/lib64. If default, " + "dlopen will search cuda from LD_LIBRARY_PATH"); + +DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); + +DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); + +static inline std::string join(const std::string& part1, + const std::string& part2) { + // directory separator + const char sep = '/'; + if (!part2.empty() && part2.front() == sep) { + return part2; + } + std::string ret; + ret.reserve(part1.size() + part2.size() + 1); + ret = part1; + if (!ret.empty() && ret.back() != sep) { + ret += sep; + } + ret += part2; + return ret; +} + +static inline void GetDsoHandleFromDefaultPath(std::string& dso_path, + void** dso_handle, + int dynload_flags) { + VLOG(3) << "Try to find library: " << dso_path + << " from default system path."; + // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH + *dso_handle = dlopen(dso_path.c_str(), dynload_flags); + +// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to +// bring System Integrity Protection (SIP), if dso_handle +// is null, search from default package path in Mac OS. 
+#if defined(__APPLE__) || defined(__OSX__) + if (nullptr == *dso_handle) { + dso_path = join("/usr/local/cuda/lib/", dso_path); + *dso_handle = dlopen(dso_path.c_str(), dynload_flags); + if (nullptr == *dso_handle) { + if (dso_path == "libcudnn.dylib") { + LOG(FATAL) + << "Note: [Recommend] copy cudnn into /usr/local/cuda/ \n" // NOLINT + << "For instance, sudo tar -xzf " + "cudnn-7.5-osx-x64-v5.0-ga.tgz -C " // NOLINT + << "/usr/local \n sudo chmod a+r " + "/usr/local/cuda/include/cudnn.h " // NOLINT + << "/usr/local/cuda/lib/libcudnn*"; + } + } + } +#endif +} + +static inline void GetDsoHandleFromSearchPath(const std::string& search_root, + const std::string& dso_name, + void** dso_handle) { + int dynload_flags = RTLD_LAZY | RTLD_LOCAL; + *dso_handle = nullptr; + + std::string dlPath = dso_name; + if (search_root.empty()) { + GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); + } else { + // search xxx.so from custom path + dlPath = join(search_root, dso_name); + *dso_handle = dlopen(dlPath.c_str(), dynload_flags); + // if not found, search from default path + if (nullptr == *dso_handle) { + LOG(WARNING) << "Failed to find dynamic library: " << dlPath << " (" + << dlerror() << ")"; + dlPath = dso_name; + GetDsoHandleFromDefaultPath(dlPath, dso_handle, dynload_flags); + } + } + + CHECK(nullptr != *dso_handle) << "Failed to find dynamic library: " << dlPath + << " (" << dlerror() << ") \n" + << "Please specify its path correctly using " + "following ways: \n" + + << "Method. set environment variable " + "LD_LIBRARY_PATH on Linux or " + << "DYLD_LIBRARY_PATH on Mac OS. \n" + << "For instance, issue command: export " + "LD_LIBRARY_PATH=... 
\n" + + << "Note: After Mac OS 10.11, using the " + "DYLD_LIBRARY_PATH is impossible " + << "unless System Integrity Protection (SIP) " + "is disabled."; +} + +void GetCublasDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle); +#endif +} + +void GetCudnnDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle); +#endif +} + +void GetCurandDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle); +#endif +} + +void GetWarpCTCDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so", dso_handle); +#endif +} + +void GetLapackDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle); +#endif +} diff --git a/paddle/platform/dynamic_loader.h b/paddle/platform/dynamic_loader.h new file mode 100644 index 0000000000..9b5ad21724 --- /dev/null +++ b/paddle/platform/dynamic_loader.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef DYNAMIC_LOAD_H_ +#define DYNAMIC_LOAD_H_ + +#include +#include +#include +#include + +/** + * @brief load the DSO of CUBLAS + * + * @param **dso_handle dso handler + * + */ +void GetCublasDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of CUDNN + * + * @param **dso_handle dso handler + * + */ +void GetCudnnDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of CURAND + * + * @param **dso_handle dso handler + * + */ +void GetCurandDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of warp-ctc + * + * @param **dso_handle dso handler + * + */ +void GetWarpCTCDsoHandle(void** dso_handle); + +/** + * @brief load the DSO of lapack + * + * @param **dso_handle dso handler + * + */ +void GetLapackDsoHandle(void** dso_handle); + +#endif // DYNAMIC_LOAD_H_ From a30754b05e1ef58b5803c3d9996ed0cc69100ac5 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 3 Jul 2017 20:41:31 +0800 Subject: [PATCH 2/7] test device_context --- paddle/platform/CMakeLists.txt | 3 + paddle/platform/device_context.h | 166 +++++++++++++++++++++++++ paddle/platform/device_context_test.cu | 29 +++++ 3 files changed, 198 insertions(+) create mode 100644 paddle/platform/device_context.h create mode 100644 paddle/platform/device_context_test.cu diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index c7d7b14518..c95b54a4df 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -2,3 +2,6 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) 
+ +cc_library(dynamic_loader SRCS dynamic_loader.cc) +nv_test(device_context_test SRCS device_context_test.cu DEPS place dynamic_loader glog gflags) diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h new file mode 100644 index 0000000000..f95aac4a36 --- /dev/null +++ b/paddle/platform/device_context.h @@ -0,0 +1,166 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifndef PADDLE_ONLY_CPU +#include "paddle/platform/cublas.h" +#include "paddle/platform/cuda.h" +#include "paddle/platform/cudnn.h" +#include "paddle/platform/curand.h" +#define EIGEN_USE_GPU +#endif + +#include "paddle/framework/enforce.h" +#include "paddle/platform/place.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace platform { + +class DeviceContext { + public: + virtual ~DeviceContext() {} +}; + +class CpuDeviceContext : public DeviceContext { + Eigen::DefaultDevice eigen_device() { + if (!eigen_device_) { + eigen_device_ = new Eigen::DefaultDevice(); + } + return *eigen_device_; + } + + private: + Eigen::DefaultDevice* eigen_device_{nullptr}; +}; + +#ifndef PADDLE_ONLY_CPU +class DeviceGuard { + public: + explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { + if (previous_ != new_place) { + paddle::platform::SetDeviceId(new_place.device); + } + } + + ~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); } + + private: + GPUPlace previous_; 
+}; + +class CudaDeviceContext : public DeviceContext { + public: + explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { + DeviceGuard guard(gpu_place_); + paddle::platform::throw_on_error(cudaStreamCreate(&stream_), + "cudaStreamCreate failed"); + eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); + eigen_device_ = new Eigen::GpuDevice(eigen_stream_); + } + + void Wait() { + paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), + "cudaStreamSynchronize failed"); + } + + cudaStream_t stream() { return stream_; } + + Eigen::GpuDevice eigen_device() { return *eigen_device_; } + + cublasHandle_t cublas_handle() { + if (!blas_handle_) { + DeviceGuard guard(gpu_place_); + PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS, + "cublasCreate failed"); + PADDLE_ENFORCE( + cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, + "cublasSetStream failed"); + } + return blas_handle_; + } + + cudnnHandle_t cudnn_handle() { + if (!dnn_handle_) { + DeviceGuard guard(gpu_place_); + PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS, + "cudnnCreate failed"); + PADDLE_ENFORCE( + cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, + "cudnnSetStream failed"); + } + return dnn_handle_; + } + + curandGenerator_t curand_generator() { + if (!rand_generator_) { + DeviceGuard guard(gpu_place_); + PADDLE_ENFORCE( + curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == + CURAND_STATUS_SUCCESS, + "curandCreateGenerator failed"); + PADDLE_ENFORCE( + curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) == + CURAND_STATUS_SUCCESS, + "curandSetPseudoRandomGeneratorSeed failed"); + PADDLE_ENFORCE( + curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS, + "curandSetStream failed"); + } + return rand_generator_; + } + + ~CudaDeviceContext() { + Wait(); + if (blas_handle_) { + PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, + "cublasDestroy 
failed"); + } + + if (dnn_handle_) { + PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, + "cudnnDestroy failed"); + } + + if (rand_generator_) { + PADDLE_ENFORCE( + curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS, + "curandDestroyGenerator failed"); + } + + delete eigen_stream_; + delete eigen_device_; + + paddle::platform::throw_on_error(cudaStreamDestroy(stream_), + "cudaStreamDestroy failed"); + } + + private: + GPUPlace gpu_place_; + cudaStream_t stream_; + + Eigen::CudaStreamDevice* eigen_stream_; + Eigen::GpuDevice* eigen_device_; + + cublasHandle_t blas_handle_{nullptr}; + + cudnnHandle_t dnn_handle_{nullptr}; + + int random_seed_; + curandGenerator_t rand_generator_{nullptr}; +}; +#endif +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/device_context_test.cu b/paddle/platform/device_context_test.cu new file mode 100644 index 0000000000..a15fb53b71 --- /dev/null +++ b/paddle/platform/device_context_test.cu @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/platform/device_context.h" +#include "gtest/gtest.h" + + +TEST(DeviceContext, CudaDevice) { + int count = paddle::platform::GetDeviceCount(); + for (int i = 0; i < count; i++) { + paddle::platform::CudaDeviceContext* device_context = new paddle::platform::CudaDeviceContext(i); + __attribute__((unused)) Eigen::GpuDevice gpu_device = device_context->eigen_device(); + __attribute__((unused)) cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); + __attribute__((unused)) cublasHandle_t cublas_handle = device_context->cublas_handle(); + __attribute__((unused)) curandGenerator_t curand_handle = device_context->curand_generator(); + delete device_context; + } +} From a77fcef3f99724e85e2239ad91683b7afe913cd8 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 3 Jul 2017 12:55:39 +0000 Subject: [PATCH 3/7] fix cuda compile error --- paddle/platform/cublas.h | 3 -- paddle/platform/cuda.h | 9 ++++++ paddle/platform/curand.h | 5 ++- paddle/platform/device_context.h | 52 +++++++++++++++++-------------- paddle/platform/dynamic_loader.cc | 4 +-- 5 files changed, 43 insertions(+), 30 deletions(-) diff --git a/paddle/platform/cublas.h b/paddle/platform/cublas.h index 70c9713325..d60eb501e9 100644 --- a/paddle/platform/cublas.h +++ b/paddle/platform/cublas.h @@ -3,7 +3,6 @@ namespace paddle { namespace dyload { -namespace dynload { std::once_flag cublas_dso_flag; void *cublas_dso_handle = nullptr; @@ -67,8 +66,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) #undef DYNAMIC_LOAD_CUBLAS_V2_WRAP #undef CUBLAS_BLAS_ROUTINE_EACH -} /* namespace dynload */ - // clang-format on #ifndef PADDLE_TYPE_DOUBLE #define CUBLAS_GEAM dynload::cublasSgeam diff --git a/paddle/platform/cuda.h b/paddle/platform/cuda.h index 8fe891f9ce..05290b0e1e 100644 --- a/paddle/platform/cuda.h +++ b/paddle/platform/cuda.h @@ -33,6 +33,15 @@ int GetDeviceCount(void) { throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); return count; } +int 
GetCurrentDeviceId(void) { + int device_id; + throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed"); + return device_id; +} + +void SetDeviceId(int device_id) { + throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed"); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/curand.h b/paddle/platform/curand.h index 692c024e6e..edff6526bd 100644 --- a/paddle/platform/curand.h +++ b/paddle/platform/curand.h @@ -3,6 +3,8 @@ namespace paddle { namespace dyload { +std::once_flag curand_dso_flag; +void *curand_dso_handle = nullptr; #ifdef PADDLE_USE_DSO #define DYNAMIC_LOAD_CURAND_WRAP(__name) \ struct DynLoad__##__name { \ @@ -31,7 +33,8 @@ namespace dyload { __macro(curandSetStream) \ __macro(curandSetPseudoRandomGeneratorSeed)\ __macro(curandGenerateUniform) \ - __macro(curandGenerateUniformDouble) + __macro(curandGenerateUniformDouble) \ + __macro(curandDestroyGenerator) // clang-format on CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f95aac4a36..65e76666a7 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -83,11 +83,12 @@ class CudaDeviceContext : public DeviceContext { cublasHandle_t cublas_handle() { if (!blas_handle_) { DeviceGuard guard(gpu_place_); - PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS, - "cublasCreate failed"); PADDLE_ENFORCE( - cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, - "cublasSetStream failed"); + paddle::dyload::cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS, + "cublasCreate failed"); + PADDLE_ENFORCE(paddle::dyload::cublasSetStream(blas_handle_, stream_) == + CUBLAS_STATUS_SUCCESS, + "cublasSetStream failed"); } return blas_handle_; } @@ -95,11 +96,12 @@ class CudaDeviceContext : public DeviceContext { cudnnHandle_t cudnn_handle() { if (!dnn_handle_) { DeviceGuard guard(gpu_place_); - PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) 
== CUDNN_STATUS_SUCCESS, - "cudnnCreate failed"); PADDLE_ENFORCE( - cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, - "cudnnSetStream failed"); + paddle::dyload::cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS, + "cudnnCreate failed"); + PADDLE_ENFORCE(paddle::dyload::cudnnSetStream(dnn_handle_, stream_) == + CUDNN_STATUS_SUCCESS, + "cudnnSetStream failed"); } return dnn_handle_; } @@ -107,17 +109,17 @@ class CudaDeviceContext : public DeviceContext { curandGenerator_t curand_generator() { if (!rand_generator_) { DeviceGuard guard(gpu_place_); + PADDLE_ENFORCE(paddle::dyload::curandCreateGenerator( + &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == + CURAND_STATUS_SUCCESS, + "curandCreateGenerator failed"); PADDLE_ENFORCE( - curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == - CURAND_STATUS_SUCCESS, - "curandCreateGenerator failed"); - PADDLE_ENFORCE( - curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) == - CURAND_STATUS_SUCCESS, + paddle::dyload::curandSetPseudoRandomGeneratorSeed( + rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE( - curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS, - "curandSetStream failed"); + PADDLE_ENFORCE(paddle::dyload::curandSetStream( + rand_generator_, stream_) == CURAND_STATUS_SUCCESS, + "curandSetStream failed"); } return rand_generator_; } @@ -125,19 +127,21 @@ class CudaDeviceContext : public DeviceContext { ~CudaDeviceContext() { Wait(); if (blas_handle_) { - PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, - "cublasDestroy failed"); + PADDLE_ENFORCE( + paddle::dyload::cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, + "cublasDestroy failed"); } if (dnn_handle_) { - PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, - "cudnnDestroy failed"); + PADDLE_ENFORCE( + paddle::dyload::cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, + "cudnnDestroy failed"); } 
if (rand_generator_) { - PADDLE_ENFORCE( - curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS, - "curandDestroyGenerator failed"); + PADDLE_ENFORCE(paddle::dyload::curandDestroyGenerator(rand_generator_) == + CURAND_STATUS_SUCCESS, + "curandDestroyGenerator failed"); } delete eigen_stream_; diff --git a/paddle/platform/dynamic_loader.cc b/paddle/platform/dynamic_loader.cc index 9036eaf642..c34abc392c 100644 --- a/paddle/platform/dynamic_loader.cc +++ b/paddle/platform/dynamic_loader.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "dynamic_loader.h" #include -#include "DynamicLoader.h" -#include "Logging.h" +#include DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " From ed18647e37f4e345f02171f29af6e22fab4790ea Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 4 Jul 2017 11:00:59 +0800 Subject: [PATCH 4/7] finish test --- paddle/platform/CMakeLists.txt | 1 - paddle/platform/cuda.h | 1 + paddle/platform/device_context.h | 170 ------------------------- paddle/platform/device_context_test.cu | 29 ----- 4 files changed, 1 insertion(+), 200 deletions(-) delete mode 100644 paddle/platform/device_context.h delete mode 100644 paddle/platform/device_context_test.cu diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index c95b54a4df..ffdc23d599 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -4,4 +4,3 @@ cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_library(dynamic_loader SRCS dynamic_loader.cc) -nv_test(device_context_test SRCS device_context_test.cu DEPS place dynamic_loader glog gflags) diff --git a/paddle/platform/cuda.h b/paddle/platform/cuda.h index 05290b0e1e..5ed36c0f02 100644 --- a/paddle/platform/cuda.h +++ b/paddle/platform/cuda.h @@ -33,6 +33,7 @@ int 
GetDeviceCount(void) { throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); return count; } + int GetCurrentDeviceId(void) { int device_id; throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed"); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h deleted file mode 100644 index 65e76666a7..0000000000 --- a/paddle/platform/device_context.h +++ /dev/null @@ -1,170 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#ifndef PADDLE_ONLY_CPU -#include "paddle/platform/cublas.h" -#include "paddle/platform/cuda.h" -#include "paddle/platform/cudnn.h" -#include "paddle/platform/curand.h" -#define EIGEN_USE_GPU -#endif - -#include "paddle/framework/enforce.h" -#include "paddle/platform/place.h" -#include "unsupported/Eigen/CXX11/Tensor" - -namespace paddle { -namespace platform { - -class DeviceContext { - public: - virtual ~DeviceContext() {} -}; - -class CpuDeviceContext : public DeviceContext { - Eigen::DefaultDevice eigen_device() { - if (!eigen_device_) { - eigen_device_ = new Eigen::DefaultDevice(); - } - return *eigen_device_; - } - - private: - Eigen::DefaultDevice* eigen_device_{nullptr}; -}; - -#ifndef PADDLE_ONLY_CPU -class DeviceGuard { - public: - explicit DeviceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { - if (previous_ != new_place) { - paddle::platform::SetDeviceId(new_place.device); - } - } - - ~DeviceGuard() { paddle::platform::SetDeviceId(previous_.device); } - - private: - GPUPlace previous_; -}; - -class CudaDeviceContext : public DeviceContext { - public: - explicit CudaDeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { - DeviceGuard guard(gpu_place_); - paddle::platform::throw_on_error(cudaStreamCreate(&stream_), - "cudaStreamCreate failed"); - eigen_stream_ = new Eigen::CudaStreamDevice(&stream_); - eigen_device_ = new Eigen::GpuDevice(eigen_stream_); - } - - void Wait() { - paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), - "cudaStreamSynchronize failed"); - } - - cudaStream_t stream() { return stream_; } - - Eigen::GpuDevice eigen_device() { return *eigen_device_; } - - cublasHandle_t cublas_handle() { - if (!blas_handle_) { - DeviceGuard guard(gpu_place_); - PADDLE_ENFORCE( - paddle::dyload::cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS, - "cublasCreate failed"); - PADDLE_ENFORCE(paddle::dyload::cublasSetStream(blas_handle_, stream_) == - CUBLAS_STATUS_SUCCESS, - "cublasSetStream 
failed"); - } - return blas_handle_; - } - - cudnnHandle_t cudnn_handle() { - if (!dnn_handle_) { - DeviceGuard guard(gpu_place_); - PADDLE_ENFORCE( - paddle::dyload::cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS, - "cudnnCreate failed"); - PADDLE_ENFORCE(paddle::dyload::cudnnSetStream(dnn_handle_, stream_) == - CUDNN_STATUS_SUCCESS, - "cudnnSetStream failed"); - } - return dnn_handle_; - } - - curandGenerator_t curand_generator() { - if (!rand_generator_) { - DeviceGuard guard(gpu_place_); - PADDLE_ENFORCE(paddle::dyload::curandCreateGenerator( - &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == - CURAND_STATUS_SUCCESS, - "curandCreateGenerator failed"); - PADDLE_ENFORCE( - paddle::dyload::curandSetPseudoRandomGeneratorSeed( - rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, - "curandSetPseudoRandomGeneratorSeed failed"); - PADDLE_ENFORCE(paddle::dyload::curandSetStream( - rand_generator_, stream_) == CURAND_STATUS_SUCCESS, - "curandSetStream failed"); - } - return rand_generator_; - } - - ~CudaDeviceContext() { - Wait(); - if (blas_handle_) { - PADDLE_ENFORCE( - paddle::dyload::cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, - "cublasDestroy failed"); - } - - if (dnn_handle_) { - PADDLE_ENFORCE( - paddle::dyload::cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, - "cudnnDestroy failed"); - } - - if (rand_generator_) { - PADDLE_ENFORCE(paddle::dyload::curandDestroyGenerator(rand_generator_) == - CURAND_STATUS_SUCCESS, - "curandDestroyGenerator failed"); - } - - delete eigen_stream_; - delete eigen_device_; - - paddle::platform::throw_on_error(cudaStreamDestroy(stream_), - "cudaStreamDestroy failed"); - } - - private: - GPUPlace gpu_place_; - cudaStream_t stream_; - - Eigen::CudaStreamDevice* eigen_stream_; - Eigen::GpuDevice* eigen_device_; - - cublasHandle_t blas_handle_{nullptr}; - - cudnnHandle_t dnn_handle_{nullptr}; - - int random_seed_; - curandGenerator_t rand_generator_{nullptr}; -}; -#endif -} // namespace platform -} // namespace paddle 
diff --git a/paddle/platform/device_context_test.cu b/paddle/platform/device_context_test.cu deleted file mode 100644 index a15fb53b71..0000000000 --- a/paddle/platform/device_context_test.cu +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/platform/device_context.h" -#include "gtest/gtest.h" - - -TEST(DeviceContext, CudaDevice) { - int count = paddle::platform::GetDeviceCount(); - for (int i = 0; i < count; i++) { - paddle::platform::CudaDeviceContext* device_context = new paddle::platform::CudaDeviceContext(i); - __attribute__((unused)) Eigen::GpuDevice gpu_device = device_context->eigen_device(); - __attribute__((unused)) cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); - __attribute__((unused)) cublasHandle_t cublas_handle = device_context->cublas_handle(); - __attribute__((unused)) curandGenerator_t curand_handle = device_context->curand_generator(); - delete device_context; - } -} From 76b7be46da5fe211d25e62712673cc01bea98d54 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 4 Jul 2017 11:16:49 +0800 Subject: [PATCH 5/7] add deps for dyload cc_library --- paddle/platform/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index ffdc23d599..4f6381b8af 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -3,4 +3,4 @@ nv_test(cuda_test 
SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -cc_library(dynamic_loader SRCS dynamic_loader.cc) +cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) From 9eeabe986d039b3fe3b28e5ef98f66d6dd2a3e31 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 4 Jul 2017 14:03:58 +0800 Subject: [PATCH 6/7] follow comments --- paddle/platform/cublas.h | 58 +++++++++++++++++++++---------- paddle/platform/cudnn.h | 38 +++++++++++++++----- paddle/platform/curand.h | 40 +++++++++++++++------ paddle/platform/dynamic_loader.cc | 16 +++++++-- paddle/platform/dynamic_loader.h | 14 ++++---- 5 files changed, 119 insertions(+), 47 deletions(-) diff --git a/paddle/platform/cublas.h b/paddle/platform/cublas.h index d60eb501e9..90704f37e6 100644 --- a/paddle/platform/cublas.h +++ b/paddle/platform/cublas.h @@ -1,7 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + #include #include "paddle/platform/dynamic_loader.h" namespace paddle { +namespace platform { namespace dyload { std::once_flag cublas_dso_flag; @@ -15,15 +32,17 @@ void *cublas_dso_handle = nullptr; * note: default dynamic linked libs */ #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - cublasStatus_t operator()(Args... 
args) { \ - typedef cublasStatus_t (*cublasFunc)(Args...); \ - std::call_once(cublas_dso_flag, GetCublasDsoHandle, &cublas_dso_handle); \ - void *p_##__name = dlsym(cublas_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ +#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + cublasStatus_t operator()(Args... args) { \ + typedef cublasStatus_t (*cublasFunc)(Args...); \ + std::call_once(cublas_dso_flag, \ + paddle::platform::dyload::GetCublasDsoHandle, \ + &cublas_dso_handle); \ + void *p_##__name = dlsym(cublas_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ } __name; // struct DynLoad__##__name #else #define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ @@ -68,17 +87,18 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) // clang-format on #ifndef PADDLE_TYPE_DOUBLE -#define CUBLAS_GEAM dynload::cublasSgeam -#define CUBLAS_GEMV dynload::cublasSgemv -#define CUBLAS_GEMM dynload::cublasSgemm -#define CUBLAS_GETRF dynload::cublasSgetrfBatched -#define CUBLAS_GETRI dynload::cublasSgetriBatched +#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam +#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv +#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm +#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched +#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched #else -#define CUBLAS_GEAM dynload::cublasDgeam -#define CUBLAS_GEMV dynload::cublasDgemv -#define CUBLAS_GEMM dynload::cublasDgemm -#define CUBLAS_GETRF dynload::cublasDgetrfBatched -#define CUBLAS_GETRI dynload::cublasDgetriBatched +#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam +#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv +#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm +#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched +#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched #endif } // namespace dyload +} 
// namespace platform } // namespace paddle diff --git a/paddle/platform/cudnn.h b/paddle/platform/cudnn.h index ab878cd555..06e2a05d86 100644 --- a/paddle/platform/cudnn.h +++ b/paddle/platform/cudnn.h @@ -1,7 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + #include #include "paddle/platform/dynamic_loader.h" namespace paddle { +namespace platform { namespace dyload { std::once_flag cudnn_dso_flag; @@ -9,15 +26,17 @@ void* cudnn_dso_handle = nullptr; #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ - using cudnn_func = decltype(__name(args...)) (*)(Args...); \ - std::call_once(cudnn_dso_flag, GetCudnnDsoHandle, &cudnn_dso_handle); \ - void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ +#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... 
args) -> decltype(__name(args...)) { \ + using cudnn_func = decltype(__name(args...)) (*)(Args...); \ + std::call_once(cudnn_dso_flag, \ + paddle::platform::dyload::GetCudnnDsoHandle, \ + &cudnn_dso_handle); \ + void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ } __name; /* struct DynLoad__##__name */ #else @@ -111,4 +130,5 @@ CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH // clang-format on } // namespace dyload +} // namespace platform } // namespace paddle diff --git a/paddle/platform/curand.h b/paddle/platform/curand.h index edff6526bd..a9cbe48ef8 100644 --- a/paddle/platform/curand.h +++ b/paddle/platform/curand.h @@ -1,20 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + #include #include "paddle/platform/dynamic_loader.h" namespace paddle { +namespace platform { namespace dyload { std::once_flag curand_dso_flag; void *curand_dso_handle = nullptr; #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - curandStatus_t operator()(Args... 
args) { \ - typedef curandStatus_t (*curandFunc)(Args...); \ - std::call_once(curand_dso_flag, GetCurandDsoHandle, &curand_dso_handle); \ - void *p_##__name = dlsym(curand_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ +#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + curandStatus_t operator()(Args... args) { \ + typedef curandStatus_t (*curandFunc)(Args...); \ + std::call_once(curand_dso_flag, \ + paddle::platform::dyload::GetCurandDsoHandle, \ + &curand_dso_handle); \ + void *p_##__name = dlsym(curand_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ } __name; /* struct DynLoad__##__name */ #else #define DYNAMIC_LOAD_CURAND_WRAP(__name) \ @@ -41,5 +60,6 @@ CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) #undef CURAND_RAND_ROUTINE_EACH #undef DYNAMIC_LOAD_CURAND_WRAP -} +} // namespace dyload +} // namespace platform } // namespace paddle diff --git a/paddle/platform/dynamic_loader.cc b/paddle/platform/dynamic_loader.cc index c34abc392c..9e0aadf8e2 100644 --- a/paddle/platform/dynamic_loader.cc +++ b/paddle/platform/dynamic_loader.cc @@ -13,8 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "dynamic_loader.h" -#include -#include +#include +#include +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. 
For instance, " @@ -30,6 +34,10 @@ DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); +namespace paddle { +namespace platform { +namespace dyload { + static inline std::string join(const std::string& part1, const std::string& part2) { // directory separator @@ -155,3 +163,7 @@ void GetLapackDsoHandle(void** dso_handle) { GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle); #endif } + +} // namespace dyload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynamic_loader.h b/paddle/platform/dynamic_loader.h index 9b5ad21724..bb58fcba17 100644 --- a/paddle/platform/dynamic_loader.h +++ b/paddle/platform/dynamic_loader.h @@ -12,13 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef DYNAMIC_LOAD_H_ -#define DYNAMIC_LOAD_H_ +#pragma once -#include -#include -#include -#include +namespace paddle { +namespace platform { +namespace dyload { /** * @brief load the DSO of CUBLAS @@ -60,4 +58,6 @@ void GetWarpCTCDsoHandle(void** dso_handle); */ void GetLapackDsoHandle(void** dso_handle); -#endif // DYNAMIC_LOAD_H_ +} // namespace dyload +} // namespace platform +} // namespace paddle From 3567ea6d7c85bdfdd42f57f6db98e5bc9bbac5e2 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 4 Jul 2017 14:58:20 +0800 Subject: [PATCH 7/7] move to dynload directory --- paddle/platform/CMakeLists.txt | 4 +-- paddle/platform/dynload/CMakeLists.txt | 1 + paddle/platform/{ => dynload}/cublas.h | 26 +++++++++---------- paddle/platform/{ => dynload}/cudnn.h | 26 +++++++++---------- paddle/platform/{ => dynload}/curand.h | 26 +++++++++---------- .../platform/{ => dynload}/dynamic_loader.cc | 4 +-- .../platform/{ => dynload}/dynamic_loader.h | 4 +-- 7 files changed, 46 insertions(+), 45 deletions(-) create mode 100644 
paddle/platform/dynload/CMakeLists.txt rename paddle/platform/{ => dynload}/cublas.h (95%) rename paddle/platform/{ => dynload}/cudnn.h (97%) rename paddle/platform/{ => dynload}/curand.h (93%) rename paddle/platform/{ => dynload}/dynamic_loader.cc (99%) rename paddle/platform/{ => dynload}/dynamic_loader.h (96%) diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 4f6381b8af..cc6b52e927 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,6 +1,6 @@ +add_subdirectory(dynload) + nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) - -cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt new file mode 100644 index 0000000000..9f829b7012 --- /dev/null +++ b/paddle/platform/dynload/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) diff --git a/paddle/platform/cublas.h b/paddle/platform/dynload/cublas.h similarity index 95% rename from paddle/platform/cublas.h rename to paddle/platform/dynload/cublas.h index 90704f37e6..c9150ac573 100644 --- a/paddle/platform/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace platform { -namespace dyload { +namespace dynload { std::once_flag cublas_dso_flag; void *cublas_dso_handle = nullptr; @@ -32,17 +32,17 @@ void *cublas_dso_handle = nullptr; * note: default dynamic linked libs */ #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - cublasStatus_t operator()(Args... 
args) { \ - typedef cublasStatus_t (*cublasFunc)(Args...); \ - std::call_once(cublas_dso_flag, \ - paddle::platform::dyload::GetCublasDsoHandle, \ - &cublas_dso_handle); \ - void *p_##__name = dlsym(cublas_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ +#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + cublasStatus_t operator()(Args... args) { \ + typedef cublasStatus_t (*cublasFunc)(Args...); \ + std::call_once(cublas_dso_flag, \ + paddle::platform::dynload::GetCublasDsoHandle, \ + &cublas_dso_handle); \ + void *p_##__name = dlsym(cublas_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ } __name; // struct DynLoad__##__name #else #define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \ @@ -99,6 +99,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) #define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched #define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched #endif -} // namespace dyload +} // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/cudnn.h b/paddle/platform/dynload/cudnn.h similarity index 97% rename from paddle/platform/cudnn.h rename to paddle/platform/dynload/cudnn.h index 06e2a05d86..c03424b375 100644 --- a/paddle/platform/cudnn.h +++ b/paddle/platform/dynload/cudnn.h @@ -19,24 +19,24 @@ limitations under the License. */ namespace paddle { namespace platform { -namespace dyload { +namespace dynload { std::once_flag cudnn_dso_flag; void* cudnn_dso_handle = nullptr; #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - auto operator()(Args... 
args) -> decltype(__name(args...)) { \ - using cudnn_func = decltype(__name(args...)) (*)(Args...); \ - std::call_once(cudnn_dso_flag, \ - paddle::platform::dyload::GetCudnnDsoHandle, \ - &cudnn_dso_handle); \ - void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ +#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + using cudnn_func = decltype(__name(args...)) (*)(Args...); \ + std::call_once(cudnn_dso_flag, \ + paddle::platform::dynload::GetCudnnDsoHandle, \ + &cudnn_dso_handle); \ + void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ } __name; /* struct DynLoad__##__name */ #else @@ -129,6 +129,6 @@ CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH // clang-format on -} // namespace dyload +} // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/curand.h b/paddle/platform/dynload/curand.h similarity index 93% rename from paddle/platform/curand.h rename to paddle/platform/dynload/curand.h index a9cbe48ef8..1ef7a8c833 100644 --- a/paddle/platform/curand.h +++ b/paddle/platform/dynload/curand.h @@ -19,21 +19,21 @@ limitations under the License. */ namespace paddle { namespace platform { -namespace dyload { +namespace dynload { std::once_flag curand_dso_flag; void *curand_dso_handle = nullptr; #ifdef PADDLE_USE_DSO -#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ - struct DynLoad__##__name { \ - template \ - curandStatus_t operator()(Args... 
args) { \ - typedef curandStatus_t (*curandFunc)(Args...); \ - std::call_once(curand_dso_flag, \ - paddle::platform::dyload::GetCurandDsoHandle, \ - &curand_dso_handle); \ - void *p_##__name = dlsym(curand_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ - } \ +#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + curandStatus_t operator()(Args... args) { \ + typedef curandStatus_t (*curandFunc)(Args...); \ + std::call_once(curand_dso_flag, \ + paddle::platform::dynload::GetCurandDsoHandle, \ + &curand_dso_handle); \ + void *p_##__name = dlsym(curand_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ } __name; /* struct DynLoad__##__name */ #else #define DYNAMIC_LOAD_CURAND_WRAP(__name) \ @@ -60,6 +60,6 @@ CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) #undef CURAND_RAND_ROUTINE_EACH #undef DYNAMIC_LOAD_CURAND_WRAP -} // namespace dyload +} // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/dynamic_loader.cc b/paddle/platform/dynload/dynamic_loader.cc similarity index 99% rename from paddle/platform/dynamic_loader.cc rename to paddle/platform/dynload/dynamic_loader.cc index 9e0aadf8e2..8ef67bad8c 100644 --- a/paddle/platform/dynamic_loader.cc +++ b/paddle/platform/dynload/dynamic_loader.cc @@ -36,7 +36,7 @@ DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); namespace paddle { namespace platform { -namespace dyload { +namespace dynload { static inline std::string join(const std::string& part1, const std::string& part2) { @@ -164,6 +164,6 @@ void GetLapackDsoHandle(void** dso_handle) { #endif } -} // namespace dyload +} // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/dynamic_loader.h b/paddle/platform/dynload/dynamic_loader.h similarity index 96% rename from paddle/platform/dynamic_loader.h rename to paddle/platform/dynload/dynamic_loader.h index 
bb58fcba17..a99b05443f 100644 --- a/paddle/platform/dynamic_loader.h +++ b/paddle/platform/dynload/dynamic_loader.h @@ -16,7 +16,7 @@ limitations under the License. */ namespace paddle { namespace platform { -namespace dyload { +namespace dynload { /** * @brief load the DSO of CUBLAS @@ -58,6 +58,6 @@ void GetWarpCTCDsoHandle(void** dso_handle); */ void GetLapackDsoHandle(void** dso_handle); -} // namespace dyload +} // namespace dynload } // namespace platform } // namespace paddle