|
|
|
@ -19,7 +19,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace platform {
|
|
|
|
|
namespace dyload {
|
|
|
|
|
namespace dynload {
|
|
|
|
|
|
|
|
|
|
std::once_flag cublas_dso_flag;
|
|
|
|
|
void *cublas_dso_handle = nullptr;
|
|
|
|
@ -32,17 +32,17 @@ void *cublas_dso_handle = nullptr;
|
|
|
|
|
* note: default dynamic linked libs
|
|
|
|
|
*/
|
|
|
|
|
#ifdef PADDLE_USE_DSO
|
|
|
|
|
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
|
|
|
|
struct DynLoad__##__name { \
|
|
|
|
|
template <typename... Args> \
|
|
|
|
|
cublasStatus_t operator()(Args... args) { \
|
|
|
|
|
typedef cublasStatus_t (*cublasFunc)(Args...); \
|
|
|
|
|
std::call_once(cublas_dso_flag, \
|
|
|
|
|
paddle::platform::dyload::GetCublasDsoHandle, \
|
|
|
|
|
&cublas_dso_handle); \
|
|
|
|
|
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
|
|
|
|
|
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
|
|
|
|
|
} \
|
|
|
|
|
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
|
|
|
|
struct DynLoad__##__name { \
|
|
|
|
|
template <typename... Args> \
|
|
|
|
|
cublasStatus_t operator()(Args... args) { \
|
|
|
|
|
typedef cublasStatus_t (*cublasFunc)(Args...); \
|
|
|
|
|
std::call_once(cublas_dso_flag, \
|
|
|
|
|
paddle::platform::dynload::GetCublasDsoHandle, \
|
|
|
|
|
&cublas_dso_handle); \
|
|
|
|
|
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
|
|
|
|
|
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
|
|
|
|
|
} \
|
|
|
|
|
} __name; // struct DynLoad__##__name
|
|
|
|
|
#else
|
|
|
|
|
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
|
|
|
|
@ -99,6 +99,6 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
|
|
|
|
|
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
|
|
|
|
|
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
|
|
|
|
|
#endif
|
|
|
|
|
} // namespace dyload
|
|
|
|
|
} // namespace dynload
|
|
|
|
|
} // namespace platform
|
|
|
|
|
} // namespace paddle
|