|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <atomic> // NOLINT
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
@ -55,44 +56,61 @@ class AutoIncrementAllocator : public ManagedAllocator {
|
|
|
|
|
template <typename Callback>
|
|
|
|
|
inline typename std::result_of<Callback(ManagedAllocator&)>::type
|
|
|
|
|
InvokeOrCreateUnderlyingAllocator(Callback callback) {
|
|
|
|
|
size_t retry_count = underlying_allocators_.size();
|
|
|
|
|
auto cur = prev_success_allocator_;
|
|
|
|
|
std::shared_ptr<std::vector<AllocatorCreator::result_type>>
|
|
|
|
|
underlying_allocators = underlying_allocators_;
|
|
|
|
|
size_t retry_count = underlying_allocators->size();
|
|
|
|
|
size_t allocator_num = retry_count;
|
|
|
|
|
auto cur = prev_success_allocator_.load();
|
|
|
|
|
while (retry_count-- > 0) { // until there retry count is zero
|
|
|
|
|
try {
|
|
|
|
|
auto res = callback(*underlying_allocators_[cur]);
|
|
|
|
|
{
|
|
|
|
|
std::lock_guard<std::mutex> guard(mtx_);
|
|
|
|
|
prev_success_allocator_ = cur;
|
|
|
|
|
}
|
|
|
|
|
auto res = callback(*((*underlying_allocators)[cur]));
|
|
|
|
|
prev_success_allocator_.store(cur);
|
|
|
|
|
return std::move(res);
|
|
|
|
|
} catch (BadAlloc&) {
|
|
|
|
|
++cur;
|
|
|
|
|
if (cur >= underlying_allocators_.size()) {
|
|
|
|
|
if (++cur >= allocator_num) {
|
|
|
|
|
cur = 0;
|
|
|
|
|
}
|
|
|
|
|
} catch (...) {
|
|
|
|
|
// if there is another type of allocation, just rethrow it.
|
|
|
|
|
throw;
|
|
|
|
|
std::rethrow_exception(std::current_exception());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// No suitable allocator
|
|
|
|
|
|
|
|
|
|
ManagedAllocator* new_allocator;
|
|
|
|
|
{
|
|
|
|
|
std::lock_guard<std::mutex> guard(mtx_);
|
|
|
|
|
underlying_allocators_.emplace_back(creator_());
|
|
|
|
|
prev_success_allocator_ = underlying_allocators_.size() - 1;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
underlying_allocators_[prev_success_allocator_]->IsAllocThreadSafe(),
|
|
|
|
|
"the underlying allocator must be thread safe. This is a program "
|
|
|
|
|
"bug.");
|
|
|
|
|
auto old_size = underlying_allocators_->size();
|
|
|
|
|
decltype(underlying_allocators_) new_allocators(
|
|
|
|
|
new std::vector<AllocatorCreator::result_type>(old_size + 1));
|
|
|
|
|
for (size_t i = 0; i < old_size; ++i) {
|
|
|
|
|
(*new_allocators)[i] = (*underlying_allocators_)[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return callback(*underlying_allocators_[prev_success_allocator_]);
|
|
|
|
|
(*new_allocators)[old_size] = creator_();
|
|
|
|
|
new_allocator = (*new_allocators)[old_size].get();
|
|
|
|
|
underlying_allocators_ = new_allocators;
|
|
|
|
|
prev_success_allocator_.store(old_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
new_allocator->IsAllocThreadSafe(),
|
|
|
|
|
"the underlying allocator must be thread safe. This is a program "
|
|
|
|
|
"bug.");
|
|
|
|
|
return callback(*new_allocator);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AllocatorCreator creator_;
|
|
|
|
|
std::vector<AllocatorCreator::result_type> underlying_allocators_;
|
|
|
|
|
size_t prev_success_allocator_{0};
|
|
|
|
|
std::mutex mtx_; // NOLINT
|
|
|
|
|
|
|
|
|
|
// Use std::shared_ptr to ensure thread-safety
|
|
|
|
|
std::shared_ptr<std::vector<AllocatorCreator::result_type>>
|
|
|
|
|
underlying_allocators_;
|
|
|
|
|
|
|
|
|
|
// Use std::atomic rather than std::mutex, since std::atomic is usually
|
|
|
|
|
// lock-free
|
|
|
|
|
std::atomic<size_t> prev_success_allocator_{0};
|
|
|
|
|
|
|
|
|
|
std::mutex mtx_;
|
|
|
|
|
};
|
|
|
|
|
} // namespace allocation
|
|
|
|
|
} // namespace memory
|
|
|
|
|