|
|
|
@ -143,6 +143,39 @@ class CudnnWorkspaceHandle {
|
|
|
|
|
std::unique_ptr<std::lock_guard<std::mutex>> guard_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9000
|
|
|
|
|
class ScopedCublasMathMode {
|
|
|
|
|
public:
|
|
|
|
|
ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
|
|
|
|
|
: handle_(handle) {
|
|
|
|
|
need_reset = false;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
|
|
|
|
|
"Failed to get old cublas math mode");
|
|
|
|
|
if (old_math_mode_ != new_math_mode) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cublasSetMathMode(handle_, new_math_mode),
|
|
|
|
|
"Failed to set old cublas math mode");
|
|
|
|
|
need_reset = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~ScopedCublasMathMode() {
|
|
|
|
|
if (need_reset) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
|
|
|
|
|
"Failed to set old cublas math mode");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
cublasHandle_t handle_;
|
|
|
|
|
cublasMath_t old_math_mode_;
|
|
|
|
|
bool need_reset;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
explicit CUDADeviceContext(CUDAPlace place);
|
|
|
|
@ -199,6 +232,18 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
callback_manager_->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9000
|
|
|
|
|
/*! \brief CublasCall may need to change cublas's config,
|
|
|
|
|
* but the cublas may be hold by multi-thread, so we should
|
|
|
|
|
* add lock here. */
|
|
|
|
|
template <typename Callback>
|
|
|
|
|
void CublasCall(Callback callback, cublasMath_t new_math) {
|
|
|
|
|
std::lock_guard<std::mutex> guard(cublas_mtx_);
|
|
|
|
|
ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
|
|
|
|
|
callback();
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
CUDAPlace place_;
|
|
|
|
|
|
|
|
|
@ -220,6 +265,8 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
|
|
|
|
|
mutable std::mutex callback_mtx_;
|
|
|
|
|
std::unique_ptr<StreamCallbackManager> callback_manager_;
|
|
|
|
|
|
|
|
|
|
// Locked by CublasCall() for the duration of the math-mode switch and the
// user callback, so that those calls do not interleave on cublas_handle_.
mutable std::mutex cublas_mtx_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|