|
|
|
@ -17,9 +17,25 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace memory {
|
|
|
|
|
namespace allocation {
|
|
|
|
|
bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
|
|
|
|
|
|
|
|
|
|
AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
|
|
|
|
|
Allocator::Attr attr) {
|
|
|
|
|
std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
|
|
|
|
|
std::lock_guard<std::mutex> guard(mtx_);
|
|
|
|
|
auto old_size = allocator_num_.load();
|
|
|
|
|
PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
|
|
|
|
|
"Allocator number exceeds capacity %d",
|
|
|
|
|
underlying_allocators_.size());
|
|
|
|
|
underlying_allocators_[old_size] = creator_();
|
|
|
|
|
prev_success_allocator_ = old_size;
|
|
|
|
|
++allocator_num_;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
underlying_allocators_[old_size]->IsAllocThreadSafe(),
|
|
|
|
|
"the underlying allocator must be thread safe. This is a program "
|
|
|
|
|
"bug.");
|
|
|
|
|
return underlying_allocators_[old_size];
|
|
|
|
|
}
|
|
|
|
|
Allocation *AutoIncrementAllocator::AllocateImpl(size_t size,
|
|
|
|
|
Allocator::Attr attr) {
|
|
|
|
|
auto cur = prev_success_allocator_.load();
|
|
|
|
|
size_t retry_count = allocator_num_.load();
|
|
|
|
|
size_t allocator_num = retry_count;
|
|
|
|
@ -27,8 +43,8 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
|
|
|
|
|
try {
|
|
|
|
|
auto res = underlying_allocators_[cur]->Allocate(size, attr);
|
|
|
|
|
prev_success_allocator_ = cur;
|
|
|
|
|
return res;
|
|
|
|
|
} catch (BadAlloc&) {
|
|
|
|
|
return res.release();
|
|
|
|
|
} catch (BadAlloc &) {
|
|
|
|
|
if (++cur >= allocator_num) {
|
|
|
|
|
cur = 0;
|
|
|
|
|
}
|
|
|
|
@ -47,32 +63,14 @@ AllocationPtr AutoIncrementAllocator::Allocate(size_t size,
|
|
|
|
|
try {
|
|
|
|
|
auto ret = underlying_allocators_[cur]->Allocate(size, attr);
|
|
|
|
|
prev_success_allocator_ = cur;
|
|
|
|
|
return ret;
|
|
|
|
|
} catch (BadAlloc&) {
|
|
|
|
|
return ret.release();
|
|
|
|
|
} catch (BadAlloc &) {
|
|
|
|
|
} catch (...) {
|
|
|
|
|
throw;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// No suitable allocator
|
|
|
|
|
return CreateNewAllocator()->Allocate(size, attr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AutoIncrementAllocator::IsAllocThreadSafe() const { return true; }
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Allocator> AutoIncrementAllocator::CreateNewAllocator() {
|
|
|
|
|
std::lock_guard<std::mutex> guard(mtx_);
|
|
|
|
|
auto old_size = allocator_num_.load();
|
|
|
|
|
PADDLE_ENFORCE_LT(old_size, underlying_allocators_.size(),
|
|
|
|
|
"Allocator number exceeds capacity %d",
|
|
|
|
|
underlying_allocators_.size());
|
|
|
|
|
underlying_allocators_[old_size] = creator_();
|
|
|
|
|
prev_success_allocator_ = old_size;
|
|
|
|
|
++allocator_num_;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
underlying_allocators_[old_size]->IsAllocThreadSafe(),
|
|
|
|
|
"the underlying allocator must be thread safe. This is a program "
|
|
|
|
|
"bug.");
|
|
|
|
|
return underlying_allocators_[old_size];
|
|
|
|
|
return CreateNewAllocator()->Allocate(size, attr).release();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace allocation
|
|
|
|
|