Fix compile

trainerSaveLoadParams
Yu Yang 7 years ago
parent a6edeb39b3
commit bc8160350b

@@ -279,8 +279,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
 paddle::platform::float16* c =
 input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
-GetBlas<float16>(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1,
-4, float16(1), c + 1, 4);
+GetBlas<paddle::platform::float16>(context).GEMM(
+false, false, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
+b + 1, 4, static_cast<paddle::platform::float16>(1), c + 1, 4);
 paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
@@ -388,12 +389,9 @@ TEST(math_function, gemm_trans_cublas_fp16) {
 paddle::platform::float16* c =
 input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
-GetBlas<float16>(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3,
-3, float16(1), c + 1, 4);
-paddle::operators::math::gemm<paddle::platform::CUDADeviceContext,
-paddle::platform::float16>(
-context, false, true, m, n, k, paddle::platform::float16(1), a, 3, b + 3,
-3, paddle::platform::float16(1), c + 1, 4);
+GetBlas<paddle::platform::float16>(context).GEMM(
+false, true, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
+b + 3, 3, static_cast<paddle::platform::float16>(1), c + 1, 4);
 paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);

Loading…
Cancel
Save