add cuda resource pool for BufferedReader, test=develop (#23152)
parent
07a1df8f50
commit
bba740710d
@ -0,0 +1,114 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include "paddle/fluid/platform/cuda_resource_pool.h"
|
||||
#include "paddle/fluid/platform/gpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
CudaStreamResourcePool::CudaStreamResourcePool() {
|
||||
int dev_cnt = platform::GetCUDADeviceCount();
|
||||
pool_.reserve(dev_cnt);
|
||||
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
|
||||
auto creator = [dev_idx] {
|
||||
platform::SetDeviceId(dev_idx);
|
||||
cudaStream_t stream;
|
||||
PADDLE_ENFORCE_CUDA_SUCCESS(
|
||||
cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking),
|
||||
platform::errors::Fatal(
|
||||
"cudaStreamCreateWithFlags raises unexpected exception"));
|
||||
return stream;
|
||||
};
|
||||
|
||||
auto deleter = [dev_idx](cudaStream_t stream) {
|
||||
platform::SetDeviceId(dev_idx);
|
||||
PADDLE_ENFORCE_CUDA_SUCCESS(
|
||||
cudaStreamDestroy(stream),
|
||||
platform::errors::Fatal(
|
||||
"cudaStreamDestroy raises unexpected exception"));
|
||||
};
|
||||
|
||||
pool_.emplace_back(
|
||||
ResourcePool<CudaStreamObject>::Create(creator, deleter));
|
||||
}
|
||||
}
|
||||
|
||||
CudaStreamResourcePool& CudaStreamResourcePool::Instance() {
|
||||
static CudaStreamResourcePool pool;
|
||||
return pool;
|
||||
}
|
||||
|
||||
std::shared_ptr<CudaStreamObject> CudaStreamResourcePool::New(int dev_idx) {
|
||||
PADDLE_ENFORCE_GE(
|
||||
dev_idx, 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"dev_idx should be not less than 0, but got %d", dev_idx));
|
||||
PADDLE_ENFORCE_LT(
|
||||
dev_idx, pool_.size(),
|
||||
platform::errors::OutOfRange(
|
||||
"dev_idx should be less than device count %d, but got %d",
|
||||
pool_.size(), dev_idx));
|
||||
return pool_[dev_idx]->New();
|
||||
}
|
||||
|
||||
CudaEventResourcePool::CudaEventResourcePool() {
|
||||
int dev_cnt = platform::GetCUDADeviceCount();
|
||||
pool_.reserve(dev_cnt);
|
||||
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
|
||||
auto creator = [dev_idx] {
|
||||
platform::SetDeviceId(dev_idx);
|
||||
cudaEvent_t event;
|
||||
PADDLE_ENFORCE_CUDA_SUCCESS(
|
||||
cudaEventCreateWithFlags(&event, cudaEventDisableTiming),
|
||||
platform::errors::Fatal(
|
||||
"cudaEventCreateWithFlags raises unexpected exception"));
|
||||
return event;
|
||||
};
|
||||
|
||||
auto deleter = [dev_idx](cudaEvent_t event) {
|
||||
platform::SetDeviceId(dev_idx);
|
||||
PADDLE_ENFORCE_CUDA_SUCCESS(
|
||||
cudaEventDestroy(event),
|
||||
platform::errors::Fatal(
|
||||
"cudaEventDestroy raises unexpected exception"));
|
||||
};
|
||||
|
||||
pool_.emplace_back(ResourcePool<CudaEventObject>::Create(creator, deleter));
|
||||
}
|
||||
}
|
||||
|
||||
CudaEventResourcePool& CudaEventResourcePool::Instance() {
|
||||
static CudaEventResourcePool pool;
|
||||
return pool;
|
||||
}
|
||||
|
||||
std::shared_ptr<CudaEventObject> CudaEventResourcePool::New(int dev_idx) {
|
||||
PADDLE_ENFORCE_GE(
|
||||
dev_idx, 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"dev_idx should be not less than 0, but got %d", dev_idx));
|
||||
PADDLE_ENFORCE_LT(
|
||||
dev_idx, pool_.size(),
|
||||
platform::errors::OutOfRange(
|
||||
"dev_idx should be less than device count %d, but got %d",
|
||||
pool_.size(), dev_idx));
|
||||
return pool_[dev_idx]->New();
|
||||
}
|
||||
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
||||
|
||||
#endif
|
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/platform/resource_pool.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
using CudaStreamObject = std::remove_pointer<cudaStream_t>::type;
|
||||
using CudaEventObject = std::remove_pointer<cudaEvent_t>::type;
|
||||
|
||||
class CudaStreamResourcePool {
|
||||
public:
|
||||
std::shared_ptr<CudaStreamObject> New(int dev_idx);
|
||||
|
||||
static CudaStreamResourcePool &Instance();
|
||||
|
||||
private:
|
||||
CudaStreamResourcePool();
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(CudaStreamResourcePool);
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<ResourcePool<CudaStreamObject>>> pool_;
|
||||
};
|
||||
|
||||
class CudaEventResourcePool {
|
||||
public:
|
||||
std::shared_ptr<CudaEventObject> New(int dev_idx);
|
||||
|
||||
static CudaEventResourcePool &Instance();
|
||||
|
||||
private:
|
||||
CudaEventResourcePool();
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(CudaEventResourcePool);
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<ResourcePool<CudaEventObject>>> pool_;
|
||||
};
|
||||
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
||||
|
||||
#endif
|
@ -0,0 +1,100 @@
|
||||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
#include "paddle/fluid/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
template <typename T>
|
||||
class ResourcePool : public std::enable_shared_from_this<ResourcePool<T>> {
|
||||
private:
|
||||
struct ResourceDeleter {
|
||||
public:
|
||||
explicit ResourceDeleter(ResourcePool<T> *pool)
|
||||
: instance_(pool->shared_from_this()) {}
|
||||
|
||||
void operator()(T *ptr) const { instance_->Restore(ptr); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<ResourcePool<T>> instance_;
|
||||
};
|
||||
|
||||
public:
|
||||
static std::shared_ptr<ResourcePool<T>> Create(
|
||||
const std::function<T *()> &creator,
|
||||
const std::function<void(T *)> &deleter) {
|
||||
return std::shared_ptr<ResourcePool<T>>(
|
||||
new ResourcePool<T>(creator, deleter));
|
||||
}
|
||||
|
||||
~ResourcePool() {
|
||||
for (auto *ptr : instances_) {
|
||||
deleter_(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<T> New() {
|
||||
std::lock_guard<std::mutex> guard(mtx_);
|
||||
T *obj = nullptr;
|
||||
if (instances_.empty()) {
|
||||
obj = creator_();
|
||||
PADDLE_ENFORCE_NOT_NULL(obj,
|
||||
platform::errors::PermissionDenied(
|
||||
"The creator should not return nullptr"));
|
||||
VLOG(10) << "Create new instance " << TypePtrName();
|
||||
} else {
|
||||
obj = instances_.back();
|
||||
instances_.pop_back();
|
||||
VLOG(10) << "Pop new instance " << TypePtrName()
|
||||
<< " from pool, size=" << instances_.size();
|
||||
}
|
||||
return std::shared_ptr<T>(obj, ResourceDeleter(this));
|
||||
}
|
||||
|
||||
private:
|
||||
static std::string TypePtrName() {
|
||||
return platform::demangle(typeid(T *).name()); // NOLINT
|
||||
}
|
||||
|
||||
private:
|
||||
ResourcePool(const std::function<T *()> &creator,
|
||||
const std::function<void(T *)> &deleter)
|
||||
: creator_(creator), deleter_(deleter) {}
|
||||
|
||||
void Restore(T *ptr) {
|
||||
std::lock_guard<std::mutex> guard(mtx_);
|
||||
instances_.emplace_back(ptr);
|
||||
VLOG(10) << "Restore " << TypePtrName()
|
||||
<< " into pool, size=" << instances_.size();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<T *> instances_;
|
||||
std::function<T *()> creator_;
|
||||
std::function<void(T *)> deleter_;
|
||||
|
||||
std::mutex mtx_;
|
||||
};
|
||||
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
Loading…
Reference in new issue