|
|
|
@ -1,7 +1,24 @@
|
|
|
|
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <cublas_v2.h>
|
|
|
|
|
#include "paddle/platform/dynamic_loader.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace platform {
|
|
|
|
|
namespace dyload {
|
|
|
|
|
|
|
|
|
|
// One-time-initialization flag. The DYNAMIC_LOAD_CUBLAS_WRAP wrappers hand it
// to std::call_once together with GetCublasDsoHandle, so the handle lookup is
// not repeated on every wrapped cuBLAS call.
// NOTE(review): this is a namespace-scope definition (not an extern
// declaration) inside a header -- confirm the header is included from only one
// translation unit, or move the definition into a source file.
std::once_flag cublas_dso_flag;
|
|
|
|
@ -15,15 +32,17 @@ void *cublas_dso_handle = nullptr;
|
|
|
|
|
* note: default dynamic linked libs
|
|
|
|
|
*/
|
|
|
|
|
#ifdef PADDLE_USE_DSO
|
|
|
|
|
/**
 * Declares a callable wrapper for the cuBLAS routine `__name` that is resolved
 * from the shared library at run time.
 *
 * The macro expands to `struct DynLoad__<name>` plus one object of that type
 * which is itself named `<name>`. Calling the object:
 *   1. invokes GetCublasDsoHandle(&cublas_dso_handle) through std::call_once,
 *      guarded by cublas_dso_flag;
 *   2. looks up the symbol `<name>` in cublas_dso_handle with dlsym();
 *   3. casts the result to a function taking the deduced argument types and
 *      returns the cublasStatus_t produced by that call.
 *
 * If dlsym() does not find the symbol, CUBLAS_STATUS_NOT_INITIALIZED is
 * returned instead of calling through a null function pointer.
 *
 * Arguments are taken by value (`Args... args`), so they are deduced exactly
 * as passed by the caller; no implicit conversion to the real cuBLAS
 * prototype takes place.
 */
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name)                                   \
  struct DynLoad__##__name {                                               \
    template <typename... Args>                                            \
    cublasStatus_t operator()(Args... args) {                              \
      using cublasFunc = cublasStatus_t (*)(Args...);                      \
      std::call_once(cublas_dso_flag,                                      \
                     paddle::platform::dyload::GetCublasDsoHandle,         \
                     &cublas_dso_handle);                                  \
      void *p_##__name = dlsym(cublas_dso_handle, #__name);                \
      /* Symbol missing: report an error status rather than crash. */      \
      if (p_##__name == nullptr) {                                         \
        return CUBLAS_STATUS_NOT_INITIALIZED;                              \
      }                                                                    \
      return reinterpret_cast<cublasFunc>(p_##__name)(args...);            \
    }                                                                      \
  } __name;  // struct DynLoad__##__name
|
|
|
|
|
#else
|
|
|
|
|
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
|
|
|
@ -68,17 +87,18 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
|
|
|
|
|
|
|
|
|
|
// clang-format on
|
|
|
|
|
// Precision-dependent aliases for the dynamically loaded cuBLAS wrappers.
// Without PADDLE_TYPE_DOUBLE the single-precision (S) routines are selected;
// with it, the double-precision (D) ones. Each alias is defined exactly once
// per branch, using the fully qualified name so that it resolves regardless of
// the namespace it is expanded in.
// NOTE(review): the enclosing namespace in this header is spelled `dyload`,
// while these aliases expand to `paddle::platform::dynload::...` -- confirm
// which spelling the wrapper objects are actually declared under.
#ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
#else
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
#endif
|
|
|
|
|
} // namespace dyload
|
|
|
|
|
} // namespace platform
|
|
|
|
|
} // namespace paddle
|
|
|
|
|