|
|
|
@ -15,6 +15,49 @@ limitations under the License. */
|
|
|
|
|
#include "MathFunctions.h"
|
|
|
|
|
#include "hl_matrix_apply.cuh"
|
|
|
|
|
#include "hl_matrix_ops.cuh"
|
|
|
|
|
#include "paddle/utils/DynamicLoad.h"
|
|
|
|
|
|
|
|
|
|
namespace dynload {

// Ensures GetLapackDsoHandle is invoked at most once across all wrappers.
std::once_flag lapack_dso_flag;
// Handle passed (by address) to GetLapackDsoHandle and then to dlsym.
// Starts out as nullptr; presumably populated by GetLapackDsoHandle --
// confirm against paddle/utils/DynamicLoad.h.
void* lapack_dso_handle = nullptr;

/**
 * DYNAMIC_LOAD_LAPACK_WRAP(__name) defines a struct DynLoad__<name> and a
 * callable object named <name> inside this namespace. Calling
 * dynload::<name>(args...) will:
 *   1. run GetLapackDsoHandle(&lapack_dso_handle) once (std::call_once),
 *   2. look the symbol "<name>" up with dlsym on that handle,
 *   3. cast it to a function pointer taking the deduced argument types and
 *      forward the arguments to it.
 *
 * The return type is taken from the declaration of the real routine via
 * decltype(__name(args...)), so the routine's header must be visible where
 * the macro is expanded.
 *
 * note: the result of dlsym is not checked here; a missing symbol results
 * in a call through a null pointer.
 */
#define DYNAMIC_LOAD_LAPACK_WRAP(__name)                                       \
  struct DynLoad__##__name {                                                   \
    template <typename... Args>                                                \
    auto operator()(Args... args) -> decltype(__name(args...)) {               \
      using lapack_func = decltype(__name(args...)) (*)(Args...);              \
      std::call_once(lapack_dso_flag, GetLapackDsoHandle, &lapack_dso_handle); \
      void* p_##__name = dlsym(lapack_dso_handle, #__name);                    \
      return reinterpret_cast<lapack_func>(p_##__name)(args...);               \
    }                                                                          \
  } __name; /* struct DynLoad__##__name */

// clang-format off
#ifdef PADDLE_USE_LAPACK
#ifdef PADDLE_USE_ATLAS
// ATLAS build: wrap the clapack_* entry points.
#define LAPACK_ROUTINE_EACH(__macro) \
  __macro(clapack_sgetrf)            \
  __macro(clapack_dgetrf)            \
  __macro(clapack_sgetri)            \
  __macro(clapack_dgetri)
#else
// Non-ATLAS build: wrap the LAPACKE_* entry points.
#define LAPACK_ROUTINE_EACH(__macro) \
  __macro(LAPACKE_sgetrf)            \
  __macro(LAPACKE_dgetrf)            \
  __macro(LAPACKE_sgetri)            \
  __macro(LAPACKE_dgetri)
#endif
// Instantiate one dynamic-load wrapper object per routine listed above so
// that dynload::<routine>(...) is actually defined for the callers below.
LAPACK_ROUTINE_EACH(DYNAMIC_LOAD_LAPACK_WRAP)
#endif
// clang-format on

}  // namespace dynload
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
@ -87,9 +130,9 @@ int getrf<float>(const CBLAS_ORDER order,
|
|
|
|
|
int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
return dynload::clapack_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
return dynload::LAPACKE_sgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
@ -106,9 +149,9 @@ int getrf<double>(const CBLAS_ORDER order,
|
|
|
|
|
int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
return dynload::clapack_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
return dynload::LAPACKE_dgetrf(order, M, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
@ -124,9 +167,9 @@ int getri<float>(const CBLAS_ORDER order,
|
|
|
|
|
const int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
return dynload::clapack_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
return dynload::LAPACKE_sgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
@ -142,9 +185,9 @@ int getri<double>(const CBLAS_ORDER order,
|
|
|
|
|
const int* ipiv) {
|
|
|
|
|
#ifdef PADDLE_USE_LAPACK
|
|
|
|
|
#ifdef PADDLE_USE_ATLAS
|
|
|
|
|
return clapack_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
return dynload::clapack_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#else
|
|
|
|
|
return LAPACKE_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
return dynload::LAPACKE_dgetri(order, N, A, lda, ipiv);
|
|
|
|
|
#endif
|
|
|
|
|
#else
|
|
|
|
|
LOG(FATAL) << "Not implemented";
|
|
|
|
|