|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function_impl.h"
|
|
|
|
|
#include "paddle/fluid/platform/float16.h"
|
|
|
|
@ -161,7 +162,8 @@ void batched_gemm<platform::CPUDeviceContext, float16>(
|
|
|
|
|
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float16 alpha, const float16* A, const float16* B, const float16 beta,
|
|
|
|
|
float16* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
float16* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
PADDLE_THROW("float16 batched_gemm not supported on CPU");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -172,7 +174,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
|
|
|
|
|
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float alpha, const float* A, const float* B, const float beta,
|
|
|
|
|
float* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
float* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
@ -194,7 +197,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
|
|
|
|
|
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const double alpha, const double* A, const double* B, const double beta,
|
|
|
|
|
double* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
double* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
@ -220,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
|
|
|
|
|
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float alpha, const float* A, const float* B, const float beta,
|
|
|
|
|
float* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
float* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
for (int k = 0; k < batchCount; ++k) {
|
|
|
|
|
const float* Ak = &A[k * strideA];
|
|
|
|
|
const float* Bk = &B[k * strideB];
|
|
|
|
@ -235,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
|
|
|
|
|
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const double alpha, const double* A, const double* B, const double beta,
|
|
|
|
|
double* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
double* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
for (int k = 0; k < batchCount; ++k) {
|
|
|
|
|
const double* Ak = &A[k * strideA];
|
|
|
|
|
const double* Bk = &B[k * strideB];
|
|
|
|
|