From d400b4192dec93c7d1f1c92867b92dd4425b2eb6 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Fri, 9 Mar 2018 15:36:47 -0800 Subject: [PATCH] fix math function arch mismatch for older GPU --- .../operators/math/math_function_test.cu | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index 442e62d563..4562853086 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -14,6 +14,8 @@ #include "gtest/gtest.h" #include "paddle/fluid/operators/math/math_function.h" +#include + void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, const std::vector& data) { PADDLE_ENFORCE_EQ(size, data.size()); @@ -22,6 +24,15 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, } } +bool is_fp16_supported(int device_id) { + cudaDeviceProp device_prop; + cudaDeviceProperties(&device_prop, device_id); + PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess); + int compute_capability = device_prop.major * 10 + device_prop.minor; + std::cout << "compute_capability is " << compute_capability << std::endl; + return compute_capability >= 53; +} + TEST(math_function, notrans_mul_trans_fp32) { using namespace paddle::framework; using namespace paddle::platform; @@ -62,6 +73,10 @@ TEST(math_function, notrans_mul_trans_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input1_gpu; Tensor input2_gpu; @@ -139,6 +154,10 @@ TEST(math_function, trans_mul_notrans_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input1_gpu; Tensor input2_gpu; @@ -237,6 +256,10 @@ TEST(math_function, gemm_notrans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input2; Tensor input3; @@ -344,6 +367,10 @@ TEST(math_function, gemm_trans_cublas_fp16) { using namespace paddle::framework; using namespace paddle::platform; + if (!is_fp16_supported(0)) { + return; + } + Tensor input1; Tensor input2; Tensor input3;