Merge pull request #10569 from reyoung/feature/matmul_support_float16_double

matmul support float16/double
commit 2924c92a2e by Yu Yang, committed via GitHub

@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> {
         reinterpret_cast<__half *>(C), ldc));
   }
 
-  template <typename... ARGS>
-  static void GEMM_BATCH(ARGS... args) {
+  static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
+                         cublasOperation_t transb, int m, int n, int k,
+                         const float16 *alpha, const float16 *A, int lda,
+                         long long int strideA, const float16 *B,  // NOLINT
+                         int ldb, long long int strideB,           // NOLINT
+                         const float16 *beta, float16 *C, int ldc,
+                         long long int strideC,  // NOLINT
+                         int batchCount) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...));
+    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+        handle, transa, transb, m, n, k,
+        reinterpret_cast<const __half *>(alpha),
+        reinterpret_cast<const __half *>(A), lda, strideA,
+        reinterpret_cast<const __half *>(B), ldb, strideB,
+        reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
+        ldc, strideC, batchCount));
 #else
     PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
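With the typed signature, callers can pass paddle::platform::float16 pointers and the wrapper performs the reinterpret_cast to __half itself, which the old variadic forward left to every call site. For reference, cublasHgemmStridedBatched runs one half-precision GEMM per batch entry, stepping A, B, and C forward by their strides between entries. A minimal single-precision CPU sketch of the same semantics (column-major, no transposes; the helper name is illustrative, not Paddle or cuBLAS code):

#include <cstdint>

// Reference semantics of a strided-batched GEMM (column-major, N/N case):
//   C_b = alpha * A_b * B_b + beta * C_b, with A_b = A + b * strideA, etc.
void ReferenceGemmStridedBatched(int m, int n, int k, float alpha,
                                 const float *A, int lda, int64_t strideA,
                                 const float *B, int ldb, int64_t strideB,
                                 float beta, float *C, int ldc,
                                 int64_t strideC, int batchCount) {
  for (int b = 0; b < batchCount; ++b) {
    const float *Ab = A + b * strideA;
    const float *Bb = B + b * strideB;
    float *Cb = C + b * strideC;
    for (int j = 0; j < n; ++j) {    // column of C
      for (int i = 0; i < m; ++i) {  // row of C
        float acc = 0.f;
        for (int p = 0; p < k; ++p) {
          acc += Ab[i + p * lda] * Bb[p + j * ldb];
        }
        Cb[i + j * ldc] = alpha * acc + beta * Cb[i + j * ldc];
      }
    }
  }
}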

@@ -172,9 +172,9 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
       c_array.data(), &ldc, 1 /* group_count */, &batchCount);
 #else
   for (int k = 0; k < batchCount; ++k) {
-    const float *Ak = &A[k * strideA];
-    const float *Bk = &B[k * strideB];
-    float *Ck = &C[k * M * N];
+    auto *Ak = &A[k * strideA];
+    auto *Bk = &B[k * strideB];
+    auto *Ck = &C[k * M * N];
     this->template GEMM<T>(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck);
   }
 #endif
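The non-MKL fallback simply splits the batched GEMM into batchCount single GEMM calls. Switching the locals from const float * to auto * matters because BatchedGEMM is templated on the element type T: the pointers now deduce to const T * and T *, so the same loop compiles for the new double and float16 instantiations. A small standalone sketch of that deduction, with illustrative names that are not Paddle's:

#include <cstddef>

// Hard-coding `const float *Ak` in the loop body would fail to compile for
// T = double; `auto *` lets the pointer type follow the template parameter.
template <typename T>
void ScaleBatches(int n, T alpha, const T *A, std::ptrdiff_t strideA, T *C,
                  std::ptrdiff_t strideC, int batchCount) {
  for (int k = 0; k < batchCount; ++k) {
    auto *Ak = &A[k * strideA];  // deduces to const T *
    auto *Ck = &C[k * strideC];  // deduces to T *
    for (int i = 0; i < n; ++i) {
      Ck[i] = alpha * Ak[i];
    }
  }
}

// ScaleBatches<float>(...) and ScaleBatches<double>(...) both instantiate
// cleanly because the pointer types are deduced rather than fixed.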

@@ -35,7 +35,8 @@ template struct SetConstant<platform::CUDADeviceContext, bool>;
 #define DEFINE_GPU_TRANS(RANK)                                            \
   template struct Transpose<platform::CUDADeviceContext, float, RANK>;    \
-  template struct Transpose<platform::CUDADeviceContext, double, RANK>;
+  template struct Transpose<platform::CUDADeviceContext, double, RANK>;   \
+  template struct Transpose<platform::CUDADeviceContext, float16, RANK>;
 
 DEFINE_GPU_TRANS(1);
 DEFINE_GPU_TRANS(2);
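The trailing backslash added to the double line is what pulls the new float16 instantiation into the macro body; without it, the float16 line would sit outside the macro and RANK would be an undefined identifier. After the change, DEFINE_GPU_TRANS(2), for example, expands to explicit instantiations for all three element types:

template struct Transpose<platform::CUDADeviceContext, float, 2>;
template struct Transpose<platform::CUDADeviceContext, double, 2>;
template struct Transpose<platform::CUDADeviceContext, float16, 2>;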

@@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
                   ops::MatMulOpGradMaker);
 REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>);
+    matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::MatMulKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::float16>);
 REGISTER_OP_CPU_KERNEL(
     matmul_grad,
-    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
+                          paddle::platform::float16>);
 #ifdef PADDLE_WITH_CUDA
 REGISTER_OP_CUDA_KERNEL(
-    matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>);
+    matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MatMulKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MatMulKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     matmul_grad,
-    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
+                          paddle::platform::float16>);
 #endif
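These macros attach one kernel per element type to matmul and matmul_grad for each place (CPU and CUDA), and the framework dispatches at run time on the input tensor's data type. A loose standalone sketch of that dispatch idea, not Paddle's actual registry and with illustrative names:

#include <functional>
#include <iostream>
#include <map>
#include <typeindex>
#include <typeinfo>

template <typename T>
void MatMulKernel() {
  std::cout << "matmul kernel for " << typeid(T).name() << "\n";
}

int main() {
  // One entry per registered element type, analogous to listing float,
  // double, and float16 in REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL.
  std::map<std::type_index, std::function<void()>> matmul_kernels;
  matmul_kernels[typeid(float)] = MatMulKernel<float>;
  matmul_kernels[typeid(double)] = MatMulKernel<double>;
  // Stand-in for paddle::platform::float16, which is not a built-in type.
  matmul_kernels[typeid(short)] = MatMulKernel<short>;

  // Dispatch by the runtime dtype of the input tensor.
  matmul_kernels.at(typeid(double))();
  return 0;
}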
