shanyi15-patch-2
Kexin Zhao 7 years ago
parent 1998d5afa2
commit 95de7617eb

@@ -14,20 +14,11 @@
 #include "paddle/fluid/operators/math/math_function.h"
 #include "gtest/gtest.h"
-#include <iostream>
 TEST(math_function, gemm_notrans_cblas) {
   paddle::framework::Tensor input1;
   paddle::framework::Tensor input2;
   paddle::framework::Tensor input3;
-  // fp16 GEMM in cublas requires GPU compute capability >= 53
-  if (GetCUDAComputeCapability(0) >= 53) {
-    std::cout << "Compute capability is " << GetCUDAComputeCapability(0)
-              << std::endl;
-    return;
-  }
   int m = 2;
   int n = 3;
   int k = 3;

@@ -14,8 +14,6 @@
 #include "gtest/gtest.h"
 #include "paddle/fluid/operators/math/math_function.h"
-#include <iostream>
 void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
                     const std::vector<float>& data) {
   PADDLE_ENFORCE_EQ(size, data.size());
@@ -65,9 +63,7 @@ TEST(math_function, notrans_mul_trans_fp16) {
   using namespace paddle::platform;
   // fp16 GEMM in cublas requires GPU compute capability >= 53
-  if (GetCUDAComputeCapability(0) >= 53) {
-    std::cout << "Compute capability is " << GetCUDAComputeCapability(0)
-              << std::endl;
+  if (GetCUDAComputeCapability(0) < 53) {
     return;
   }
@@ -149,7 +145,7 @@ TEST(math_function, trans_mul_notrans_fp16) {
   using namespace paddle::platform;
   // fp16 GEMM in cublas requires GPU compute capability >= 53
-  if (GetCUDAComputeCapability(0) >= 53) {
+  if (GetCUDAComputeCapability(0) < 53) {
     return;
   }
@@ -252,7 +248,7 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
   using namespace paddle::platform;
   // fp16 GEMM in cublas requires GPU compute capability >= 53
-  if (GetCUDAComputeCapability(0) >= 53) {
+  if (GetCUDAComputeCapability(0) < 53) {
     return;
   }
@@ -364,7 +360,7 @@ TEST(math_function, gemm_trans_cublas_fp16) {
   using namespace paddle::platform;
  // fp16 GEMM in cublas requires GPU compute capability >= 53
-  if (GetCUDAComputeCapability(0) >= 53) {
+  if (GetCUDAComputeCapability(0) < 53) {
    return;
   }
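
The fix flips the guard: cublas fp16 GEMM needs a GPU with compute capability >= 5.3, so each test should return early (skip) when the capability is below 53, not when it is at or above it. Below is a minimal sketch of the same check written against the plain CUDA runtime API; the helper name SkipFp16GemmTest is hypothetical and only stands in for the GetCUDAComputeCapability(0) < 53 guard used in the tests.

#include <cuda_runtime.h>
#include <iostream>

// Hypothetical helper mirroring the corrected guard above: skip the fp16
// GEMM test when device 0 has compute capability below 5.3.
bool SkipFp16GemmTest() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
    return true;  // no usable CUDA device, nothing to test
  }
  int capability = prop.major * 10 + prop.minor;  // e.g. 5.3 -> 53
  if (capability < 53) {
    std::cout << "Compute capability is " << capability
              << ", skipping fp16 GEMM test" << std::endl;
    return true;
  }
  return false;
}

Each fp16 test body would then start with if (SkipFp16GemmTest()) { return; }, which is equivalent to the if (GetCUDAComputeCapability(0) < 53) { return; } pattern this commit introduces.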
