|
|
|
|
@ -102,7 +102,7 @@ class Blas {
|
|
|
|
|
T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
|
|
|
|
|
int ldc) const;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class Blas
|
|
|
|
|
template <typename T>
|
|
|
|
|
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
|
|
|
|
|
const int K) const;
|
|
|
|
|
@ -126,7 +126,7 @@ class Blas {
|
|
|
|
|
const int* indx, const int* pntrb, const int* pntre, const T* b,
|
|
|
|
|
const int* ldb, const T* beta, T* c, const int* ldc) const;
|
|
|
|
|
|
|
|
|
|
#if !defined(PADDLE_WITH_CUDA)
|
|
|
|
|
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
|
|
|
|
|
template <typename T>
|
|
|
|
|
void MatMulWithHead(const framework::Tensor& mat_a,
|
|
|
|
|
const MatDescriptor& dim_a,
|
|
|
|
|
@ -135,7 +135,7 @@ class Blas {
|
|
|
|
|
framework::Tensor* mat_out, T beta,
|
|
|
|
|
bool mat_y_split_vertical) const;
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
|
#endif // @} End Group MKLML: class Blas
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void MatMul(const int M, const int N, const int K, const T* A, const T* B,
|
|
|
|
|
@ -210,7 +210,8 @@ class Blas {
|
|
|
|
|
int K, T alpha, const T** A, const T** B, T beta, T** C,
|
|
|
|
|
int batchCount) const;
|
|
|
|
|
|
|
|
|
|
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
|
|
|
|
|
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
|
|
|
|
|
!defined(PADDLE_WITH_HIP)
|
|
|
|
|
template <typename T>
|
|
|
|
|
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
|
|
|
|
|
int W1, int H1, int W2, int H2, T alpha, const T* A,
|
|
|
|
|
@ -235,7 +236,7 @@ class Blas {
|
|
|
|
|
CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B,
|
|
|
|
|
int ldb) const;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
|
|
|
|
|
template <typename T>
|
|
|
|
|
void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const;
|
|
|
|
|
|
|
|
|
|
@ -262,7 +263,7 @@ class BlasT : private Blas<DeviceContext> {
|
|
|
|
|
Base()->template GEMM<T>(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML
|
|
|
|
|
#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class BlasT
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
T* GEMM_ALLOC(ARGS... args) const {
|
|
|
|
|
return Base()->template GEMM_ALLOC<T>(args...);
|
|
|
|
|
@ -288,13 +289,13 @@ class BlasT : private Blas<DeviceContext> {
|
|
|
|
|
Base()->template CSRMM<T>(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if !defined(PADDLE_WITH_CUDA)
|
|
|
|
|
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
void MatMulWithHead(ARGS... args) const {
|
|
|
|
|
Base()->template MatMulWithHead<T>(args...);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
|
#endif // @} End Group MKLML: class BlasT
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
void MatMul(ARGS... args) const {
|
|
|
|
|
@ -386,7 +387,7 @@ class BlasT : private Blas<DeviceContext> {
|
|
|
|
|
Base()->template TRSM<T>(args...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
void BatchedGETRF(ARGS... args) const {
|
|
|
|
|
Base()->template BatchedGETRF<T>(args...);
|
|
|
|
|
@ -429,3 +430,6 @@ inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#include "paddle/fluid/operators/math/blas_impl.cu.h"
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef PADDLE_WITH_HIP
|
|
|
|
|
#include "paddle/fluid/operators/math/blas_impl.hip.h"
|
|
|
|
|
#endif
|
|
|
|
|
|