|
|
|
@ -15,7 +15,7 @@ limitations under the License. */
|
|
|
|
|
#include "MathFunctions.h"
|
|
|
|
|
#include "hl_matrix_apply.cuh"
|
|
|
|
|
#include "hl_matrix_ops.cuh"
|
|
|
|
|
#include "paddle/utils/DynamicLoad.h"
|
|
|
|
|
#include "paddle/utils/DynamicLoader.h"
|
|
|
|
|
|
|
|
|
|
namespace dynload {
|
|
|
|
|
|
|
|
|
@ -32,7 +32,7 @@ void* lapack_dso_handle = nullptr;
|
|
|
|
|
#define DYNAMIC_LOAD_LAPACK_WRAP(__name) \
|
|
|
|
|
struct DynLoad__##__name { \
|
|
|
|
|
template <typename... Args> \
|
|
|
|
|
auto operator()(Args... args)->decltype(__name(args...)) { \
|
|
|
|
|
auto operator()(Args... args) -> decltype(__name(args...)) { \
|
|
|
|
|
using lapack_func = decltype(__name(args...)) (*)(Args...); \
|
|
|
|
|
std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle); \
|
|
|
|
|
void* p_##__name = dlsym(lapack_dso_handle, #__name); \
|
|
|
|
@ -41,24 +41,27 @@ void* lapack_dso_handle = nullptr;
|
|
|
|
|
} __name; // struct DynLoad__##__name
|
|
|
|
|
|
|
|
|
|
// clang-format off
// Backend selection for the dynamically loaded LAPACK routines.
//
// LAPACK_ROUTINE_EACH(__macro) applies __macro to every routine that must be
// wrapped, and the PADDLE_* aliases give call sites one backend-independent
// spelling. With PADDLE_USE_ATLAS the clapack_* names are used, otherwise the
// LAPACKE_* names.
//
// The concrete routine names (not the PADDLE_* aliases) are passed to __macro
// on purpose: DYNAMIC_LOAD_LAPACK_WRAP applies `#` and `##` to its argument,
// and the preprocessor does not macro-expand operands of those operators, so
// passing an alias would make dlsym() look up the literal alias text.
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
#define LAPACK_ROUTINE_EACH(__macro) \
  __macro(clapack_sgetrf)            \
  __macro(clapack_dgetrf)            \
  __macro(clapack_sgetri)            \
  __macro(clapack_dgetri)
#define PADDLE_SGETRF clapack_sgetrf
#define PADDLE_DGETRF clapack_dgetrf
#define PADDLE_SGETRI clapack_sgetri
#define PADDLE_DGETRI clapack_dgetri
#else
#define LAPACK_ROUTINE_EACH(__macro) \
  __macro(LAPACKE_sgetrf)            \
  __macro(LAPACKE_dgetrf)            \
  __macro(LAPACKE_sgetri)            \
  __macro(LAPACKE_dgetri)
#define PADDLE_SGETRF LAPACKE_sgetrf
#define PADDLE_DGETRF LAPACKE_dgetrf
#define PADDLE_SGETRI LAPACKE_sgetri
#define PADDLE_DGETRI LAPACKE_dgetri
#endif  // PADDLE_USE_ATLAS
// clang-format on

// Instantiate one DynLoad__<routine> wrapper object per selected routine.
LAPACK_ROUTINE_EACH(DYNAMIC_LOAD_LAPACK_WRAP)
#endif  // PADDLE_USE_LAPACK
|
|
|
|
|
} // namespace dynload
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -130,16 +133,7 @@ int getrf<float>(const CBLAS_ORDER order,
|
|
|
|
|
float* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return dynload::clapack_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return dynload::LAPACKE_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return 0;
|
|
|
|
|
return dynload::PADDLE_SGETRF(order, M, N, A, lda, ipiv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -149,16 +143,7 @@ int getrf<double>(const CBLAS_ORDER order,
|
|
|
|
|
double* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return dynload::clapack_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return dynload::LAPACKE_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return 0;
|
|
|
|
|
return dynload::PADDLE_DGETRF(order, M, N, A, lda, ipiv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -167,16 +152,7 @@ int getri<float>(const CBLAS_ORDER order,
|
|
|
|
|
float* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
const int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return dynload::clapack_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return dynload::LAPACKE_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return 0;
|
|
|
|
|
return dynload::PADDLE_SGETRI(order, N, A, lda, ipiv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -185,15 +161,7 @@ int getri<double>(const CBLAS_ORDER order,
|
|
|
|
|
double* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
const int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return dynload::clapack_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return dynload::LAPACKE_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return dynload::PADDLE_DGETRI(order, N, A, lda, ipiv);
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|