fix cuda compile error

gangliao-patch-1
qijun 8 years ago
parent a30754b05e
commit a77fcef3f9

@ -3,7 +3,6 @@
namespace paddle { namespace paddle {
namespace dyload { namespace dyload {
namespace dynload {
std::once_flag cublas_dso_flag; std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr; void *cublas_dso_handle = nullptr;
@ -67,8 +66,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP #undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH #undef CUBLAS_BLAS_ROUTINE_EACH
} /* namespace dynload */
// clang-format on // clang-format on
#ifndef PADDLE_TYPE_DOUBLE #ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM dynload::cublasSgeam #define CUBLAS_GEAM dynload::cublasSgeam

@ -33,6 +33,15 @@ int GetDeviceCount(void) {
throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed");
return count; return count;
} }
// Asks the CUDA runtime (cudaGetDevice) which device is currently selected
// and returns that device's id. The runtime status code is handed to
// throw_on_error together with the message "cudaGetDevice failed".
int GetCurrentDeviceId(void) {
  int current_id;
  throw_on_error(cudaGetDevice(&current_id), "cudaGetDevice failed");
  return current_id;
}
// Selects the CUDA device identified by `device_id` through cudaSetDevice.
// The runtime status code is handed to throw_on_error together with the
// message "cudaSetDevice failed".
void SetDeviceId(int device_id) {
  throw_on_error(
      cudaSetDevice(device_id),
      "cudaSetDevice failed");
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle

@ -3,6 +3,8 @@
namespace paddle { namespace paddle {
namespace dyload { namespace dyload {
std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr;
#ifdef PADDLE_USE_DSO #ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \ #define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
@ -31,7 +33,8 @@ namespace dyload {
__macro(curandSetStream) \ __macro(curandSetStream) \
__macro(curandSetPseudoRandomGeneratorSeed)\ __macro(curandSetPseudoRandomGeneratorSeed)\
__macro(curandGenerateUniform) \ __macro(curandGenerateUniform) \
__macro(curandGenerateUniformDouble) __macro(curandGenerateUniformDouble) \
__macro(curandDestroyGenerator)
// clang-format on // clang-format on
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP) CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)

@ -83,11 +83,12 @@ class CudaDeviceContext : public DeviceContext {
cublasHandle_t cublas_handle() { cublasHandle_t cublas_handle() {
if (!blas_handle_) { if (!blas_handle_) {
DeviceGuard guard(gpu_place_); DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasCreate failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(
cublasSetStream(blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, paddle::dyload::cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed"); "cublasCreate failed");
PADDLE_ENFORCE(paddle::dyload::cublasSetStream(blas_handle_, stream_) ==
CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed");
} }
return blas_handle_; return blas_handle_;
} }
@ -95,11 +96,12 @@ class CudaDeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() { cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) { if (!dnn_handle_) {
DeviceGuard guard(gpu_place_); DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnCreate failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(
cudnnSetStream(dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, paddle::dyload::cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed"); "cudnnCreate failed");
PADDLE_ENFORCE(paddle::dyload::cudnnSetStream(dnn_handle_, stream_) ==
CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed");
} }
return dnn_handle_; return dnn_handle_;
} }
@ -107,17 +109,17 @@ class CudaDeviceContext : public DeviceContext {
curandGenerator_t curand_generator() { curandGenerator_t curand_generator() {
if (!rand_generator_) { if (!rand_generator_) {
DeviceGuard guard(gpu_place_); DeviceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::dyload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(
curandCreateGenerator(&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == paddle::dyload::curandSetPseudoRandomGeneratorSeed(
CURAND_STATUS_SUCCESS, rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed");
PADDLE_ENFORCE(
curandSetPseudoRandomGeneratorSeed(rand_generator_, random_seed_) ==
CURAND_STATUS_SUCCESS,
"curandSetPseudoRandomGeneratorSeed failed"); "curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(paddle::dyload::curandSetStream(
curandSetStream(rand_generator_, stream_) == CURAND_STATUS_SUCCESS, rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
"curandSetStream failed"); "curandSetStream failed");
} }
return rand_generator_; return rand_generator_;
} }
@ -125,19 +127,21 @@ class CudaDeviceContext : public DeviceContext {
~CudaDeviceContext() { ~CudaDeviceContext() {
Wait(); Wait();
if (blas_handle_) { if (blas_handle_) {
PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS, PADDLE_ENFORCE(
"cublasDestroy failed"); paddle::dyload::cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
"cublasDestroy failed");
} }
if (dnn_handle_) { if (dnn_handle_) {
PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS, PADDLE_ENFORCE(
"cudnnDestroy failed"); paddle::dyload::cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
"cudnnDestroy failed");
} }
if (rand_generator_) { if (rand_generator_) {
PADDLE_ENFORCE( PADDLE_ENFORCE(paddle::dyload::curandDestroyGenerator(rand_generator_) ==
curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS, CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed"); "curandDestroyGenerator failed");
} }
delete eigen_stream_; delete eigen_stream_;

@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "dynamic_loader.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include "DynamicLoader.h" #include <glog/logging.h>
#include "Logging.h"
DEFINE_string(cudnn_dir, "", DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, " "Specify path for loading libcudnn.so. For instance, "

Loading…
Cancel
Save