|
|
|
@ -224,7 +224,8 @@ void batched_gemm<platform::CPUDeviceContext, float>(
|
|
|
|
|
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const float alpha, const float* A, const float* B, const float beta,
|
|
|
|
|
float* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
float* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
for (int k = 0; k < batchCount; ++k) {
|
|
|
|
|
const float* Ak = &A[k * strideA];
|
|
|
|
|
const float* Bk = &B[k * strideB];
|
|
|
|
@ -239,7 +240,8 @@ void batched_gemm<platform::CPUDeviceContext, double>(
|
|
|
|
|
const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
|
|
|
|
|
const double alpha, const double* A, const double* B, const double beta,
|
|
|
|
|
double* C, const int batchCount, const int strideA, const int strideB) {
|
|
|
|
|
double* C, const int batchCount, const int64_t strideA,
|
|
|
|
|
const int64_t strideB) {
|
|
|
|
|
for (int k = 0; k < batchCount; ++k) {
|
|
|
|
|
const double* Ak = &A[k * strideA];
|
|
|
|
|
const double* Bk = &B[k * strideB];
|
|
|
|
|