You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/dataset/util/allocator.h

178 lines
5.7 KiB

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_ALLOCATOR_H_
#define DATASET_UTIL_ALLOCATOR_H_
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
#include "dataset/util/memory_pool.h"
namespace mindspore {
namespace dataset {
// Allocator that satisfies the standard Allocator requirements by drawing its memory from a MemoryPool.
// Standard containers and std::allocator_traits look these members and typedefs up by name, so do not
// rename or remove any of them (function names, typedefs, rebind, etc.).
template <typename T>
class Allocator {
 public:
  template <typename U>
  friend class Allocator;
  using value_type = T;
  using pointer = T *;
  using const_pointer = const T *;
  using reference = T &;
  using const_reference = const T &;
  using size_type = uint64_t;
  template <typename U>
  struct rebind {
    using other = Allocator<U>;
  };
  // Containers carry the allocator (and therefore the pool) along on copy assignment, move assignment and swap.
  using propagate_on_container_copy_assignment = std::true_type;
  using propagate_on_container_move_assignment = std::true_type;
  using propagate_on_container_swap = std::true_type;

  /// \brief Create an allocator that draws all of its memory from the given pool.
  /// \param b Pool shared with every copy and rebound copy of this allocator.
  explicit Allocator(const std::shared_ptr<MemoryPool> &b) : pool_(b) {}
  ~Allocator() = default;

  /// \brief Rebinding constructor: share the pool of an allocator for another value type.
  template <typename U>
  explicit Allocator(Allocator<U> const &rhs) : pool_(rhs.pool_) {}

  /// \brief Two allocators compare equal iff they share the same pool, i.e. memory from one may be freed by the other.
  template <typename U>
  bool operator==(Allocator<U> const &rhs) const {
    return pool_ == rhs.pool_;
  }
  template <typename U>
  bool operator!=(Allocator<U> const &rhs) const {
    return pool_ != rhs.pool_;
  }

  /// \brief Obtain uninitialized storage for n objects of type T from the pool.
  /// \param n Number of objects (not bytes).
  /// \return Pointer to the storage.
  /// \throws std::bad_array_new_length if n * sizeof(T) does not fit in std::size_t.
  /// \throws std::bad_alloc if the pool reports out of memory.
  /// \throws std::exception for any other pool failure.
  pointer allocate(std::size_t n) {
    // Without this guard n * sizeof(T) can wrap around to a small byte count and the caller would
    // silently receive a buffer far smaller than requested. bad_array_new_length derives from bad_alloc.
    if (n > std::numeric_limits<std::size_t>::max() / sizeof(T)) {
      throw std::bad_array_new_length();
    }
    void *p = nullptr;
    Status rc = pool_->Allocate(n * sizeof(T), &p);
    if (rc.IsOk()) {
      return static_cast<pointer>(p);
    } else if (rc.IsOutofMemory()) {
      throw std::bad_alloc();
    } else {
      throw std::exception();
    }
  }

  /// \brief Return storage obtained from allocate() to the pool.
  /// \param p Pointer previously returned by allocate().
  /// \param n Element count; unused, the pool only needs the pointer.
  void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); }

  /// \brief Forwards to the pool's get_max_size().
  /// NOTE(review): this member is not const-qualified, so std::allocator_traits<Allocator>::max_size (which calls
  /// it on a const allocator) falls back to its own default instead. Also confirm whether get_max_size() is in
  /// bytes or in objects; the Allocator requirements expect a count of objects.
  size_type max_size() { return pool_->get_max_size(); }

 private:
  std::shared_ptr<MemoryPool> pool_;
};
/// \brief Owns an array of T obtained from an allocator. Like std::lock_guard, it releases the memory (destroying
/// the elements first) when the object goes out of scope.
/// \tparam T The type of object to be allocated
/// \tparam C Allocator. Defaults to std::allocator<T>
template <typename T, typename C = std::allocator<T>>
class MemGuard {
 public:
  using allocator = C;
  MemGuard() : n_(0) {}
  // Initializers listed in declaration order (alloc_, then n_) to match actual initialization order.
  explicit MemGuard(allocator a) : alloc_(std::move(a)), n_(0) {}
  // There is no copy constructor nor copy assignment operator because the memory is solely owned by this object.
  MemGuard(const MemGuard &) = delete;
  MemGuard &operator=(const MemGuard &) = delete;
  // Moving is supported; ownership transfers and the source is left empty.
  MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {
    // lhs.ptr_ is now null; keep lhs.n_ consistent so lhs.GetSizeInBytes() reports 0.
    lhs.n_ = 0;
  }
  MemGuard &operator=(MemGuard &&lhs) noexcept {
    if (this != &lhs) {
      this->deallocate();
      alloc_ = std::move(lhs.alloc_);
      ptr_ = std::move(lhs.ptr_);
      n_ = lhs.n_;
      lhs.n_ = 0;
    }
    return *this;
  }

  /// \brief Explicitly destroy the elements and return the memory to the allocator, if anything is held.
  void deallocate() {
    if (ptr_) {
      // Take the pointer back out of ptr_ so its deleter never runs; the memory belongs to alloc_.
      T *p = ptr_.release();
      if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
        for (size_t i = 0; i < n_; ++i) {
          p[i].~T();
        }
      }
      alloc_.deallocate(p, n_);
      n_ = 0;
    }
  }

  /// \brief Allocate memory (with emplace feature). Any previously held memory is released first. If n is 0, no new
  /// memory is allocated.
  /// \param n Number of objects of type T to be allocated
  /// \tparam Args Extra arguments passed to the constructor of every element. Each element is constructed from the
  /// same arguments, so they are passed as lvalues (never moved from). For arithmetic T the elements are left
  /// uninitialized unless arguments are supplied.
  /// \return Status::OK() on success. StatusCode::kOutOfMemory if allocation throws std::bad_alloc. An unexpected
  /// error for any other exception. On failure the guard holds no memory.
  template <typename... Args>
  Status allocate(size_t n, Args &&... args) noexcept {
    try {
      deallocate();
      if (n > 0) {
        T *data = alloc_.allocate(n);
        const bool construct_elements = !std::is_arithmetic<T>::value || sizeof...(Args) > 0;
        if (construct_elements) {
          size_t constructed = 0;
          try {
            for (; constructed < n; ++constructed) {
              // args are deliberately not std::forward-ed: forwarding inside a loop would move from the same
              // rvalue arguments repeatedly, leaving every element after the first built from moved-from values.
              std::allocator_traits<C>::construct(alloc_, data + constructed, args...);
            }
          } catch (...) {
            // Roll back a partial construction: destroy the elements already built, give the storage back,
            // then let the handlers below translate the exception into a Status.
            for (size_t i = 0; i < constructed; ++i) {
              data[i].~T();
            }
            alloc_.deallocate(data, n);
            throw;
          }
        }
        // reset() cannot throw. It also avoids installing default_delete<T[]> as the deleter, which would
        // delete[] allocator-owned memory if ever invoked. The deleter is never meant to run; deallocate() releases.
        ptr_.reset(data);
        n_ = n;
      }
    } catch (const std::bad_alloc &e) {
      return Status(StatusCode::kOutOfMemory);
    } catch (const std::exception &e) {
      RETURN_STATUS_UNEXPECTED(e.what());
    } catch (...) {
      // Without this, a non-std exception from T's constructor would escape a noexcept function and terminate.
      RETURN_STATUS_UNEXPECTED("Unknown exception in MemGuard::allocate");
    }
    return Status::OK();
  }

  ~MemGuard() noexcept { deallocate(); }

  /// \brief Getter function
  /// \return The pointer to the memory allocated (nullptr if nothing is held)
  T *GetPointer() const { return ptr_.get(); }
  /// \brief Getter function
  /// \return The pointer to the memory allocated (nullptr if nothing is held)
  T *GetMutablePointer() { return ptr_.get(); }
  /// \brief Overload [] operator to access a particular element
  /// \param x Index of the element. Must be less than the number of elements allocated; not bounds-checked.
  /// \return Pointer to the x-th element
  T *operator[](size_t x) { return GetMutablePointer() + x; }
  /// \brief Overload [] operator to access a particular element
  /// \param x Index of the element. Must be less than the number of elements allocated; not bounds-checked.
  /// \return Pointer to the x-th element
  T *operator[](size_t x) const { return GetPointer() + x; }
  /// \brief Return how many bytes are allocated in total
  /// \return Number of bytes allocated in total (0 when nothing is held)
  size_t GetSizeInBytes() const { return n_ * sizeof(T); }

 private:
  allocator alloc_;
  std::unique_ptr<T[], std::function<void(T *)>> ptr_;
  size_t n_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_ALLOCATOR_H_