|
|
|
@ -92,26 +92,24 @@ platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
|
|
|
|
|
const platform::Place& place, const cudaStream_t& stream) {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(place));
|
|
|
|
|
auto place_stream = std::make_pair(place, stream);
|
|
|
|
|
{
|
|
|
|
|
std::unique_lock<std::mutex> lock(mtx_);
|
|
|
|
|
if (!device_allocator_.count(place_stream)) {
|
|
|
|
|
device_allocator_[place_stream].reset(new TemporaryAllocator(place));
|
|
|
|
|
device_allocator_[place_stream]->SetCallback([stream]() {
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaGetLastError());
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
std::unique_lock<std::mutex> lock(mtx_);
|
|
|
|
|
auto it = device_allocator_.find(place_stream);
|
|
|
|
|
if (it == device_allocator_.end()) {
|
|
|
|
|
auto tmp_allocator = new TemporaryAllocator(place);
|
|
|
|
|
tmp_allocator->SetCallback([stream]() {
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaGetLastError());
|
|
|
|
|
});
|
|
|
|
|
device_allocator_[place_stream].reset(tmp_allocator);
|
|
|
|
|
return *tmp_allocator;
|
|
|
|
|
} else {
|
|
|
|
|
return *it->second;
|
|
|
|
|
}
|
|
|
|
|
return *device_allocator_.at(place_stream);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
|
|
|
|
|
const platform::CUDADeviceContext& dev_ctx) {
|
|
|
|
|
auto place_stream = std::make_pair(dev_ctx.GetPlace(), dev_ctx.stream());
|
|
|
|
|
if (device_allocator_.count(place_stream)) {
|
|
|
|
|
return *device_allocator_.at(place_stream);
|
|
|
|
|
}
|
|
|
|
|
return Get(dev_ctx.GetPlace(), dev_ctx.stream());
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
@ -325,7 +323,7 @@ Place CUDADeviceContext::GetPlace() const { return place_; }
|
|
|
|
|
void CUDADeviceContext::Wait() const {
|
|
|
|
|
auto& allocator =
|
|
|
|
|
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
|
|
|
|
|
allocator.Release([=]() {
|
|
|
|
|
allocator.Release([this]() {
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
|
|
|
|
|
PADDLE_ENFORCE(cudaGetLastError());
|
|
|
|
|
});
|
|
|
|
|