|
|
|
@ -11,6 +11,7 @@ limitations under the License. */
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <mutex> // NOLINT
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -100,6 +101,7 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
|
|
|
|
|
// Invokes `callback` and then records `ev` on stream_, with both steps
// performed while mtx_ is held.
//
// ev:       CUDA event handed to cudaEventRecord.
// callback: callable invoked with no arguments; it is taken by value and any
//           value it returns is discarded.
template <typename Callback>
|
|
|
|
|
void RecordEvent(cudaEvent_t ev, Callback callback) {
|
|
|
|
|
// Hold mtx_ for the rest of the function so the callback and the event
// record below happen as one step relative to other holders of mtx_.
std::lock_guard<std::mutex> guard(mtx_);
|
|
|
|
|
// The callback runs before the event is recorded, still under the lock.
// NOTE(review): presumably the callback enqueues work on stream_ that the
// event is meant to follow — confirm against callers.
callback();
|
|
|
|
|
// NOTE(review): PADDLE_ENFORCE is defined elsewhere; presumably it reports a
// failure when cudaEventRecord returns an error status — confirm.
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
|
|
|
|
|
}
|
|
|
|
@ -116,6 +118,8 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
int compute_capability;
|
|
|
|
|
int multi_process;
|
|
|
|
|
int max_threads_per_mp;
|
|
|
|
|
|
|
|
|
|
std::mutex mtx_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|