|
|
|
@ -25,8 +25,8 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const float alpha, const float* A,
|
|
|
|
|
const float* B, const float beta, float* C,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
int lda = K;
|
|
|
|
|
int ldb = N;
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
|
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
@ -40,8 +40,8 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
|
|
|
|
|
const double* B, const double beta,
|
|
|
|
|
double* C,
|
|
|
|
|
platform::DeviceContext* context) {
|
|
|
|
|
int lda = K;
|
|
|
|
|
int ldb = N;
|
|
|
|
|
int lda = (transA == CblasNoTrans) ? K : M;
|
|
|
|
|
int ldb = (transB == CblasNoTrans) ? N : K;
|
|
|
|
|
int ldc = N;
|
|
|
|
|
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
|
|
|
|
|
beta, C, ldc);
|
|
|
|
|