|
|
|
@ -15,6 +15,54 @@ limitations under the License. */
|
|
|
|
|
#include "MathFunctions.h"
|
|
|
|
|
#include "hl_matrix_apply.cuh"
|
|
|
|
|
#include "hl_matrix_ops.cuh"
|
|
|
|
|
#include "paddle/utils/DynamicLoader.h"
|
|
|
|
|
|
|
|
|
|
namespace dynload {
|
|
|
|
|
|
|
|
|
|
std::once_flag lapack_dso_flag;
|
|
|
|
|
void* lapack_dso_handle = nullptr;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* The following macro definition can generate structs
|
|
|
|
|
* (for each function) to dynamic load lapack routine
|
|
|
|
|
* via operator overloading.
|
|
|
|
|
*
|
|
|
|
|
* note: default dynamic linked libs
|
|
|
|
|
*/
|
|
|
|
|
#define DYNAMIC_LOAD_LAPACK_WRAP(__name) \
|
|
|
|
|
struct DynLoad__##__name { \
|
|
|
|
|
template <typename... Args> \
|
|
|
|
|
auto operator()(Args... args) -> decltype(__name(args...)) { \
|
|
|
|
|
using lapack_func = decltype(__name(args...)) (*)(Args...); \
|
|
|
|
|
std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle); \
|
|
|
|
|
void* p_##__name = dlsym(lapack_dso_handle, #__name); \
|
|
|
|
|
return reinterpret_cast<lapack_func>(p_##__name)(args...); \
|
|
|
|
|
} \
|
|
|
|
|
} __name; // struct DynLoad__##__name
|
|
|
|
|
|
|
|
|
|
// clang-format off
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
#define PADDLE_SGETRF clapack_sgetrf
|
|
|
|
|
#define PADDLE_DGETRF clapack_dgetrf
|
|
|
|
|
#define PADDLE_SGETRI clapack_sgetri
|
|
|
|
|
#define PADDLE_DGETRI clapack_dgetri
|
|
|
|
|
#else
|
|
|
|
|
#define PADDLE_SGETRF LAPACKE_sgetrf
|
|
|
|
|
#define PADDLE_DGETRF LAPACKE_dgetrf
|
|
|
|
|
#define PADDLE_SGETRI LAPACKE_sgetri
|
|
|
|
|
#define PADDLE_DGETRI LAPACKE_dgetri
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#define LAPACK_ROUTINE_EACH(__macro) \
|
|
|
|
|
__macro(PADDLE_SGETRF) \
|
|
|
|
|
__macro(PADDLE_DGETRF) \
|
|
|
|
|
__macro(PADDLE_SGETRI) \
|
|
|
|
|
__macro(PADDLE_DGETRI)
|
|
|
|
|
// clang-format on
|
|
|
|
|
|
|
|
|
|
LAPACK_ROUTINE_EACH(DYNAMIC_LOAD_LAPACK_WRAP)
|
|
|
|
|
|
|
|
|
|
} // namespace dynload
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
@ -85,16 +133,7 @@ int getrf<float>(const CBLAS_ORDER order,
|
|
|
|
|
float* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return 0;
|
|
|
|
|
return dynload::PADDLE_SGETRF(order, M, N, A, lda, ipiv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -104,16 +143,7 @@ int getrf<double>(const CBLAS_ORDER order,
|
|
|
|
|
double* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return 0;
|
|
|
|
|
return dynload::PADDLE_DGETRF(order, M, N, A, lda, ipiv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -122,16 +152,7 @@ int getri<float>(const CBLAS_ORDER order,
|
|
|
|
|
float* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
const int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return 0;
|
|
|
|
|
return dynload::PADDLE_SGETRI(order, N, A, lda, ipiv);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
@ -140,15 +161,7 @@ int getri<double>(const CBLAS_ORDER order,
|
|
|
|
|
double* A,
|
|
|
|
|
const int lda,
|
|
|
|
|
const int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|
#endif
|
|
|
|
|
return dynload::PADDLE_DGETRI(order, N, A, lda, ipiv);
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|