|
|
@ -314,14 +314,23 @@ CUDADeviceContext::~CUDADeviceContext() {
|
|
|
|
// Returns a copy of the Place stored in this context's `place_` member.
// Const accessor: it reads `place_` and modifies nothing.
// NOTE(review): this definition appeared twice in a row (identical text);
// a second out-of-line definition of the same member function is a
// redefinition error, so only one copy is kept.
Place CUDADeviceContext::GetPlace() const { return place_; }
|
|
|
|
|
|
|
|
|
|
|
|
void CUDADeviceContext::Wait() const {
|
|
|
|
void CUDADeviceContext::Wait() const {
|
|
|
|
cudaError_t e_sync = cudaStreamSynchronize(stream_);
|
|
|
|
cudaError_t e_sync = cudaSuccess;
|
|
|
|
if (e_sync != 0) {
|
|
|
|
#if !defined(_WIN32)
|
|
|
|
|
|
|
|
e_sync = cudaStreamSynchronize(stream_);
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
|
|
|
while (e_sync = cudaStreamQuery(stream_)) {
|
|
|
|
|
|
|
|
if (e_sync == cudaErrorNotReady) continue;
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (cudaSuccess != e_sync) {
|
|
|
|
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
|
|
|
|
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
|
|
|
|
<< " errno: " << e_sync;
|
|
|
|
<< " errno: " << e_sync;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
cudaError_t e_get = cudaGetLastError();
|
|
|
|
cudaError_t e_get = cudaGetLastError();
|
|
|
|
if (e_get != 0) {
|
|
|
|
if (cudaSuccess != e_get) {
|
|
|
|
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
|
|
|
|
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
|
|
|
|
<< " errno: " << e_get;
|
|
|
|
<< " errno: " << e_get;
|
|
|
|
}
|
|
|
|
}
|
|
|
|