|
|
|
@ -261,11 +261,7 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) {
|
|
|
|
|
|
|
|
|
|
struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
|
|
|
|
|
hl_event_t hl_event = &hl_event_st;
|
|
|
|
|
|
|
|
|
|
bool isNotReady = false;
|
|
|
|
|
do {
|
|
|
|
|
hl_cuda_event_query(hl_event, isNotReady);
|
|
|
|
|
} while (isNotReady == cudaErrorNotReady);
|
|
|
|
|
while (!hl_cuda_event_is_ready(hl_event)) {}
|
|
|
|
|
|
|
|
|
|
KeVectorSum<128><<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(A_d, t_resource.gpu_mem, dimM);
|
|
|
|
@ -275,7 +271,10 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) {
|
|
|
|
|
hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
|
|
|
|
|
hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);
|
|
|
|
|
|
|
|
|
|
CHECK_SYNC("hl_vector_sum failed");
|
|
|
|
|
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
|
|
|
|
|
cudaError_t err = (cudaError_t)hl_get_device_last_error();
|
|
|
|
|
CHECK_EQ(cudaSuccess, err)
|
|
|
|
|
<< "CUDA error: " << hl_get_device_error_string((size_t)err);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <int blockSize>
|
|
|
|
@ -317,11 +316,7 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) {
|
|
|
|
|
|
|
|
|
|
struct _hl_event_st hl_event_st = {.cu_event = t_resource.event};
|
|
|
|
|
hl_event_t hl_event = &hl_event_st;
|
|
|
|
|
|
|
|
|
|
bool isNotReady = false;
|
|
|
|
|
do {
|
|
|
|
|
hl_cuda_event_query(hl_event, isNotReady);
|
|
|
|
|
} while (isNotReady == cudaErrorNotReady);
|
|
|
|
|
while (!hl_cuda_event_is_ready(hl_event)) {}
|
|
|
|
|
|
|
|
|
|
KeVectorAbsSum<128><<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(A_d, t_resource.gpu_mem, dimM);
|
|
|
|
@ -331,5 +326,8 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) {
|
|
|
|
|
hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT);
|
|
|
|
|
hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event);
|
|
|
|
|
|
|
|
|
|
CHECK_SYNC("hl_vector_abs_sum failed");
|
|
|
|
|
hl_stream_synchronize(HPPL_STREAM_DEFAULT);
|
|
|
|
|
cudaError_t err = (cudaError_t)hl_get_device_last_error();
|
|
|
|
|
CHECK_EQ(cudaSuccess, err)
|
|
|
|
|
<< "CUDA error: " << hl_get_device_error_string((size_t)err);
|
|
|
|
|
}
|
|
|
|
|