|
|
|
@ -24,17 +24,23 @@ struct CBlas;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct CBlas<float> {
|
|
|
|
|
static constexpr auto GEMM = cblas_sgemm;
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM(ARGS... args) {
|
|
|
|
|
cblas_sgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct CBlas<double> {
|
|
|
|
|
static constexpr auto GEMM = cblas_dgemm;
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
static void GEMM(ARGS... args) {
|
|
|
|
|
cblas_dgemm(args...);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct CBlas<platform::float16> {
|
|
|
|
|
void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
|
|
|
|
|
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|