|
|
|
@ -22,41 +22,6 @@ namespace memory {
|
|
|
|
|
namespace allocation {
|
|
|
|
|
|
|
|
|
|
// Constructs a buffered allocator on top of `allocator` with the default
// division plan: one boundary for each power of two 2^0, 2^1, ...,
// 2^(8 * sizeof(size_t) - 1). The plan is then handed to
// InitAndEnforceCheck together with ownership of `allocator`.
BufferedAllocator::BufferedAllocator(std::unique_ptr<Allocator>&& allocator) {
  constexpr size_t kBitCount = 8 * sizeof(size_t);

  std::vector<size_t> power_of_two_plan;
  power_of_two_plan.reserve(kBitCount);
  for (size_t shift = 0; shift != kBitCount; ++shift) {
    power_of_two_plan.push_back(static_cast<size_t>(1) << shift);
  }

  InitAndEnforceCheck(std::move(allocator), power_of_two_plan);
}
|
|
|
|
|
|
|
|
|
|
// Constructs a buffered allocator on top of `allocator` using the
// caller-supplied `division_plan`. Both arguments are forwarded unchanged to
// InitAndEnforceCheck, which performs the actual setup and validation.
BufferedAllocator::BufferedAllocator(
    std::unique_ptr<Allocator>&& allocator,
    const std::vector<size_t>& division_plan) {
  InitAndEnforceCheck(std::move(allocator), division_plan);
}
|
|
|
|
|
|
|
|
|
|
// On destruction, release everything still held in the buffer. FlushImpl is
// called directly (not Flush), so the mutex is not acquired here.
BufferedAllocator::~BufferedAllocator() {
  FlushImpl();
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::FlushImpl() {
|
|
|
|
|
for (auto& v : allocations_) {
|
|
|
|
|
for (auto& pair : v) {
|
|
|
|
|
underlying_allocator_->FreeUniquePtr(std::move(pair.second));
|
|
|
|
|
}
|
|
|
|
|
v.clear();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::Flush() {
|
|
|
|
|
if (mtx_) {
|
|
|
|
|
std::lock_guard<std::mutex> lock(*mtx_);
|
|
|
|
|
FlushImpl();
|
|
|
|
|
} else {
|
|
|
|
|
FlushImpl();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::InitAndEnforceCheck(
|
|
|
|
|
std::unique_ptr<Allocator>&& allocator,
|
|
|
|
|
const std::vector<size_t>& division_plan) {
|
|
|
|
|
underlying_allocator_.reset(
|
|
|
|
|
dynamic_cast<UnmanagedAllocator*>(allocator.release()));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
@ -65,141 +30,54 @@ void BufferedAllocator::InitAndEnforceCheck(
|
|
|
|
|
if (underlying_allocator_->IsAllocThreadSafe()) {
|
|
|
|
|
mtx_.reset(new std::mutex());
|
|
|
|
|
}
|
|
|
|
|
constexpr size_t kMax = std::numeric_limits<size_t>::max();
|
|
|
|
|
if (division_plan.empty()) {
|
|
|
|
|
division_plan_.assign({0, kMax});
|
|
|
|
|
} else {
|
|
|
|
|
auto from = division_plan.front() == 0 ? division_plan.begin() + 1
|
|
|
|
|
: division_plan.begin();
|
|
|
|
|
auto to = division_plan.back() == kMax ? division_plan.end() - 1
|
|
|
|
|
: division_plan.end();
|
|
|
|
|
division_plan_.reserve(to - from + 2);
|
|
|
|
|
division_plan_.push_back(0);
|
|
|
|
|
division_plan_.insert(division_plan_.end(), from, to);
|
|
|
|
|
division_plan_.push_back(kMax);
|
|
|
|
|
for (size_t i = 1; i < division_plan_.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LT(division_plan_[i - 1], division_plan_[i],
|
|
|
|
|
"Division plan must be strictly sorted");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
allocations_.resize(division_plan_.size() - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::InsertAllocationImpl(
|
|
|
|
|
std::unique_ptr<Allocation>&& allocation) {
|
|
|
|
|
auto size = allocation->size();
|
|
|
|
|
auto idx = GetListIndex(size);
|
|
|
|
|
allocations_[idx].emplace(size, std::move(allocation));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::InsertAllocation(
|
|
|
|
|
std::unique_ptr<Allocation>&& allocation) {
|
|
|
|
|
if (mtx_) {
|
|
|
|
|
std::lock_guard<std::mutex> lock(*mtx_);
|
|
|
|
|
InsertAllocationImpl(std::move(allocation));
|
|
|
|
|
} else {
|
|
|
|
|
InsertAllocationImpl(std::move(allocation));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns true when half of `actual_size` (rounded down) is strictly less
// than `requested_size`, i.e. the buffered block is small enough relative to
// the request to be reused for it.
bool BufferedAllocator::Match(size_t actual_size, size_t requested_size) {
  const size_t half_of_actual = actual_size / 2;
  return requested_size > half_of_actual;
}
|
|
|
|
|
|
|
|
|
|
size_t BufferedAllocator::GetListIndex(size_t size) {
|
|
|
|
|
auto it =
|
|
|
|
|
std::upper_bound(division_plan_.begin(), division_plan_.end(), size);
|
|
|
|
|
return static_cast<size_t>(it - division_plan_.begin()) - 1;
|
|
|
|
|
}
|
|
|
|
|
// On destruction, ask FreeCache to release as many bytes as a size_t can
// express, i.e. everything that is cached.
//
// The budget is written as static_cast<size_t>(-1) rather than -1UL: the
// literal -1UL has type unsigned long, which is only 32 bits wide on LLP64
// targets (e.g. 64-bit Windows), so it would request just 4 GiB - 1 there
// instead of the maximum size_t value.
BufferedAllocator::~BufferedAllocator() {
  FreeCache(static_cast<size_t>(-1));
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Allocation> BufferedAllocator::RemoveAllocationImpl(
|
|
|
|
|
size_t size) {
|
|
|
|
|
auto idx = GetListIndex(size);
|
|
|
|
|
auto& allocation_map = allocations_[idx];
|
|
|
|
|
auto it = allocation_map.lower_bound(size);
|
|
|
|
|
// Only remove allocation whose size is not more than twice of requested size
|
|
|
|
|
if (it != allocation_map.end()) {
|
|
|
|
|
if (Match(it->second->size(), size)) {
|
|
|
|
|
auto ret = std::move(it->second);
|
|
|
|
|
allocation_map.erase(it);
|
|
|
|
|
return ret;
|
|
|
|
|
} else {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
while (++idx < allocations_.size() && Match(division_plan_[idx], size)) {
|
|
|
|
|
auto& allocation_map = allocations_[idx];
|
|
|
|
|
if (!allocation_map.empty()) {
|
|
|
|
|
auto it = allocation_map.begin();
|
|
|
|
|
if (Match(it->second->size(), size)) {
|
|
|
|
|
auto ret = std::move(it->second);
|
|
|
|
|
allocation_map.erase(it);
|
|
|
|
|
return ret;
|
|
|
|
|
} else {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<Allocation> BufferedAllocator::Allocate(size_t size,
|
|
|
|
|
Allocator::Attr attr) {
|
|
|
|
|
std::unique_ptr<Allocation> result;
|
|
|
|
|
{
|
|
|
|
|
platform::LockGuardPtr<std::mutex> guard(mtx_);
|
|
|
|
|
auto it = allocations_.lower_bound(size);
|
|
|
|
|
if (it != allocations_.end() && it->first < size * 2) {
|
|
|
|
|
result = std::move(it->second);
|
|
|
|
|
allocations_.erase(it);
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Allocation> BufferedAllocator::RemoveAllocation(size_t size) {
|
|
|
|
|
if (mtx_) {
|
|
|
|
|
std::lock_guard<std::mutex> lock(*mtx_);
|
|
|
|
|
return RemoveAllocationImpl(size);
|
|
|
|
|
} else {
|
|
|
|
|
return RemoveAllocationImpl(size);
|
|
|
|
|
if (result) {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Allocation> BufferedAllocator::Allocate(size_t size,
|
|
|
|
|
Allocator::Attr attr) {
|
|
|
|
|
auto ret = RemoveAllocation(size);
|
|
|
|
|
if (!ret) {
|
|
|
|
|
try {
|
|
|
|
|
return underlying_allocator_->Allocate(size, attr);
|
|
|
|
|
} catch (BadAlloc&) {
|
|
|
|
|
// if allocation failed, try to free some memorys from buffers
|
|
|
|
|
FreeAllocations(size);
|
|
|
|
|
return underlying_allocator_->Allocate(size, attr);
|
|
|
|
|
}
|
|
|
|
|
try {
|
|
|
|
|
return underlying_allocator_->Allocate(size, attr);
|
|
|
|
|
} catch (BadAlloc&) {
|
|
|
|
|
FreeCache(size);
|
|
|
|
|
return underlying_allocator_->Allocate(size, attr);
|
|
|
|
|
}
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::FreeAllocationsImpl(size_t size) {
|
|
|
|
|
void BufferedAllocator::FreeCache(size_t size) {
|
|
|
|
|
platform::LockGuardPtr<std::mutex> guard(mtx_);
|
|
|
|
|
if (UNLIKELY(size == 0)) return;
|
|
|
|
|
size_t cur = 0;
|
|
|
|
|
for (auto& alloc_map : allocations_) {
|
|
|
|
|
// use reverse iterator to free large allocations first
|
|
|
|
|
while (!alloc_map.empty()) {
|
|
|
|
|
auto it = --(alloc_map.end());
|
|
|
|
|
cur += it->second->size();
|
|
|
|
|
underlying_allocator_->FreeUniquePtr(std::move(it->second));
|
|
|
|
|
alloc_map.erase(it);
|
|
|
|
|
if (cur >= size) return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::FreeAllocations(size_t size) {
|
|
|
|
|
if (mtx_) {
|
|
|
|
|
std::lock_guard<std::mutex> lock(*mtx_);
|
|
|
|
|
FreeAllocationsImpl(size);
|
|
|
|
|
} else {
|
|
|
|
|
FreeAllocationsImpl(size);
|
|
|
|
|
while (!allocations_.empty()) { // free the largest
|
|
|
|
|
auto it = --allocations_.end();
|
|
|
|
|
cur += it->second->size();
|
|
|
|
|
underlying_allocator_->FreeUniquePtr(std::move(it->second));
|
|
|
|
|
allocations_.erase(it);
|
|
|
|
|
if (cur >= size) return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferedAllocator::FreeUniquePtr(std::unique_ptr<Allocation> allocation) {
|
|
|
|
|
InsertAllocation(std::move(allocation));
|
|
|
|
|
platform::LockGuardPtr<std::mutex> guard(mtx_);
|
|
|
|
|
allocations_.emplace(allocation->size(), std::move(allocation));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reports whether this allocator holds a mutex (mtx_ is non-null); the
// locking entry points only synchronize when it does.
bool BufferedAllocator::IsAllocThreadSafe() const {
  return static_cast<bool>(mtx_);
}
|
|
|
|
|
|
|
|
|
|
const std::vector<size_t>& BufferedAllocator::GetDivisionPlan() const {
|
|
|
|
|
return division_plan_;
|
|
|
|
|
bool BufferedAllocator::IsAllocThreadSafe() const {
|
|
|
|
|
return this->underlying_allocator_->IsAllocThreadSafe();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace allocation
|
|
|
|
|