Async write buffer

pull/8926/head
Jesse Lee 4 years ago
parent dd86f0234d
commit e59b5f3a4a

@ -14,34 +14,96 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
CachedSharedMemory::CachedSharedMemory(int32_t port, size_t val_in_GB)
: shared_memory_sz_in_gb_(val_in_GB), port_(port), num_numa_nodes_(-1), sub_pool_sz_(-1) {
// We create the shared memory and we will destroy it. All other client just detach only.
shm_.RemoveResourcesOnExit();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
CachedSharedMemory::~CachedSharedMemory() = default;
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
// Transfer the ownership of this pointer. Any future error in the processing we will have
// the destructor of *out to deal.
(*out).reset(ba);
Status CachedSharedMemory::Init() {
CacheServer &cs = CacheServer::GetInstance();
num_numa_nodes_ = cs.GetNumaNodeCount();
// Generate the ftok using a combination of port.
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
ba->shm_.SetPublicKey(shm_key);
RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
shm_.SetPublicKey(shm_key);
// Value is in GB. Convert into bytes.
int64_t sz = val_in_GB * 1073741824L;
RETURN_IF_NOT_OK(ba->shm_.Create(sz));
ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
int64_t shm_mem_sz = shared_memory_sz_in_gb_ * 1073741824L;
RETURN_IF_NOT_OK(shm_.Create(shm_mem_sz));
MS_LOG(INFO) << "Creation of shared memory successful. Shared memory key " << shm_.GetKey();
// Interleave the memory.
cs.GetHWControl()->InterleaveMemory(shm_.SharedMemoryBaseAddr(), shm_mem_sz);
// We will create a number of sub pool out of shared memory to reduce latch contention
int32_t num_of_pools = num_numa_nodes_;
if (num_numa_nodes_ == 1) {
num_of_pools = shared_memory_sz_in_gb_ * 2;
}
sub_pool_sz_ = shm_mem_sz / num_of_pools;
// If each subpool is too small, readjust the number of pools
constexpr int64 min_subpool_sz = 512 * 1048576L;
if (sub_pool_sz_ < min_subpool_sz) {
sub_pool_sz_ = min_subpool_sz;
num_of_pools = shm_mem_sz / min_subpool_sz;
}
shm_pool_.reserve(num_of_pools);
for (auto i = 0; i < num_of_pools; ++i) {
void *ptr = static_cast<char *>(shm_.SharedMemoryBaseAddr()) + i * sub_pool_sz_;
shm_pool_.push_back(std::make_unique<ArenaImpl>(ptr, sub_pool_sz_));
}
mux_ = std::make_unique<std::mutex[]>(num_of_pools);
return Status::OK();
}
Status CachedSharedMemory::CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
auto mem_pool = std::unique_ptr<CachedSharedMemory>(new CachedSharedMemory(port, val_in_GB));
RETURN_IF_NOT_OK(mem_pool->Init());
*out = std::move(mem_pool);
return Status::OK();
}
Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
Status rc;
RETURN_UNEXPECTED_IF_NULL(p);
auto begin_slot = client_id % shm_pool_.size();
auto slot = begin_slot;
do {
std::unique_lock<std::mutex> lock(mux_[slot]);
rc = shm_pool_[slot]->Allocate(sz, p);
if (rc.IsOutofMemory()) {
slot = (slot + 1) % shm_pool_.size();
}
} while (rc.IsError() && slot != begin_slot);
if (rc.IsError()) {
return rc;
}
return Status::OK();
}
void CachedSharedMemory::DeallocateSharedMemory(int32_t client_id, void *p) {
auto begin_slot = client_id % shm_pool_.size();
auto slot = begin_slot;
auto start_addr = static_cast<char *>(SharedMemoryBaseAddr());
bool found = false;
do {
auto ptr = start_addr + slot * sub_pool_sz_;
if (ptr <= p && p < (ptr + sub_pool_sz_)) {
std::unique_lock<std::mutex> lock(mux_[slot]);
shm_pool_[slot]->Deallocate(p);
found = true;
break;
} else {
slot = (slot + 1) % shm_pool_.size();
}
} while (slot != begin_slot);
if (!found) {
MS_LOG(ERROR) << "Programming error. Can't find the arena the pointer " << p << " comes from";
}
}
} // namespace dataset
} // namespace mindspore

@ -18,25 +18,29 @@
#include <memory>
#include <mutex>
#include <vector>
#include <string>
#include <utility>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public MemoryPool {
/// This is like a CircularPool but each arena is in shared memory and
/// possibly bind to a numa socket.
class CachedSharedMemory {
public:
// Disable copy and assignment constructor
CachedSharedMemoryArena(const CachedSharedMemoryArena &) = delete;
CachedSharedMemoryArena &operator=(const CachedSharedMemoryArena &) = delete;
~CachedSharedMemoryArena() override;
CachedSharedMemory(const CachedSharedMemory &) = delete;
CachedSharedMemory &operator=(const CachedSharedMemory &) = delete;
~CachedSharedMemory();
/// \brief Create an Arena in shared memory
/// \param[out] p_ba Pointer to a unique_ptr
/// \param shmkey Shared memory key
/// \param val_in_GB size of shared memory in gigabyte
/// \return Status object
static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
static Status CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB);
/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
@ -44,45 +48,29 @@ class CachedSharedMemoryArena : public MemoryPool {
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return impl_->get_base_addr(); }
/// As a derived class of MemoryPool, we have to implement the following
/// But we simply transfer the call to the implementation class
Status Allocate(size_t size, void **pVoid) override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->Allocate(size, pVoid);
}
Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->Reallocate(pVoid, old_sz, new_sz);
}
void Deallocate(void *pVoid) override {
std::unique_lock<std::mutex> lock(mux_);
impl_->Deallocate(pVoid);
}
uint64_t get_max_size() const override { return impl_->get_max_size(); }
int PercentFree() const override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->PercentFree();
}
/// \brief Dump the memory allocation block.
friend std::ostream &operator<<(std::ostream &os, const CachedSharedMemoryArena &s) {
os << *(s.impl_);
return os;
}
const void *SharedMemoryBaseAddr() const { return shm_.SharedMemoryBaseAddr(); }
void *SharedMemoryBaseAddr() { return shm_.SharedMemoryBaseAddr(); }
/// \brief Get the shared memory key of the shared memory
SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }
/// \brief Allocate shared memory for a given pipeline
Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);
/// \brief Deallocate shared memory for a given pipeline
void DeallocateSharedMemory(int32_t client_id, void *p);
private:
mutable std::mutex mux_;
int32_t val_in_GB_;
int32_t shared_memory_sz_in_gb_;
int32_t port_;
SharedMemory shm_;
std::unique_ptr<ArenaImpl> impl_;
std::vector<std::unique_ptr<ArenaImpl>> shm_pool_;
std::unique_ptr<std::mutex[]> mux_;
int32_t num_numa_nodes_;
int64_t sub_pool_sz_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
CachedSharedMemory(int32_t port, size_t val_in_GB);
Status Init();
};
} // namespace dataset
} // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -38,6 +38,8 @@
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
@ -50,6 +52,7 @@ class CacheClient {
friend class CreateCacheRequest;
friend class CacheRowRequest;
friend class BatchFetchRequest;
friend class BatchCacheRowsRequest;
/// \brief A builder to help creating a CacheClient object
class Builder {
@ -180,6 +183,11 @@ class CacheClient {
/// \return Status object
Status GetStat(CacheServiceStat *);
/// \brief Get the state of a cache server
/// \param[in/out] Pointer to a int8_t
/// \return Status object
Status GetState(int8_t *);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
/// \return Status object
@ -230,6 +238,7 @@ class CacheClient {
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
int32_t GetClientId() const { return client_id_; }
/// MergeOp will notify us when the server can't cache any more rows.
/// We will stop any attempt to fetch any rows that are most likely
@ -250,6 +259,20 @@ class CacheClient {
return false;
}
// Default size of the async write buffer
constexpr static int64_t kAsyncBufferSize = 16 * 1048576L; // 16M
constexpr static int32_t kNumAsyncBuffer = 2;
/// Force a final flush to the cache server. Must be called when receving eoe.
Status FlushAsyncWriteBuffer() {
if (async_buffer_stream_) {
return async_buffer_stream_->SyncFlush(true);
}
return Status::OK();
}
Status AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in);
private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
@ -288,6 +311,62 @@ class CacheClient {
std::set<row_id_type> gap_;
};
std::unique_ptr<CacheMissKeys> cache_miss_keys_;
/// A data stream of back-to-back serialized tensor rows.
class AsyncBufferStream {
public:
AsyncBufferStream();
~AsyncBufferStream();
/// \brief Initialize an Ascyn write buffer
Status Init(CacheClient *cc);
/// A worker will call the API AsyncWrite to put a TensorRow into the data stream.
/// A background thread will stream the data to the cache server.
/// The result of calling AsyncWrite is not immediate known or it can be the last
/// result of some previous flush.
/// \note Need to call SyncFlush to do the final flush.
Status AsyncWrite(const TensorRow &row);
Status SyncFlush(bool blocking = false);
/// This maps a physical shared memory to the data stream.
class AsyncWriter {
public:
friend class AsyncBufferStream;
Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);
private:
std::shared_ptr<BatchCacheRowsRequest> rq;
void *buffer_;
int32_t num_ele_; // How many tensor rows in this buffer
int64_t begin_addr_; // Start of logical address of the data stream
std::atomic<int64_t> end_addr_; // End of the logical address of the data stream
std::atomic<int64_t> bytes_avail_; // Number of bytes remain
};
/// \brief Release the shared memory during shutdown
/// /note but needs comm layer to be alive.
Status ReleaseBuffer();
private:
Status flush_rc_;
WaitPost writer_wp_;
WaitPost flush_wp_;
RWLock mux_;
TaskGroup vg_;
CacheClient *cc_;
int64_t offset_addr_;
AsyncWriter buf_arr_[kNumAsyncBuffer];
int32_t cur_;
std::atomic<int64_t> next_addr_;
/// \brief Entry point of the async flush thread.
Status AsyncFlush();
};
std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
/// \brief Serialize a Tensor into the async buffer.
Status AsyncWriteRow(const TensorRow &row);
};
} // namespace dataset
} // namespace mindspore

@ -37,6 +37,10 @@ namespace dataset {
/// For too small amount, we won't get any benefit using shared memory method because we need
/// two rpc requests to use shared memory method.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief Default size (in GB) of shared memory we are going to create
constexpr static int32_t kDefaultSharedMemorySize = 4;
/// \brief Memory Cap ratio used by the server
constexpr static float kDefaultMemoryCapRatio = 0.8;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
@ -46,7 +50,15 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
constexpr static int32_t kSharedMessageSize = 2048;
/// \brief State of CacheService at the server.
enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
enum class CacheServiceState : int8_t {
kNone = 0,
kBuildPhase = 1,
kFetchPhase = 2,
kNoLocking = 3,
kOutOfMemory = 4,
kNoSpace = 5,
kError = 127
};
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object

@ -23,8 +23,7 @@
namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port) : port_(port) {
// Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key yet until the unix_socket_ is created
@ -70,14 +69,6 @@ Status CacheServerGreeterImpl::Run() {
server_ = builder.BuildAndStart();
if (server_) {
MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
shm_key_ = shm_pool_->GetKey();
MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
auto cs = CacheServer::GetInstance().GetHWControl();
// This shared memory is a hot memory and we will interleave among all the numa nodes.
cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
#endif
} else {
std::string errMsg = "Fail to start server. ";
if (port_tcpip != port_) {
@ -147,14 +138,18 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
// We will distribute the request evenly (or randomly) over all the numa nodes.
// The exception is BatchFetch which we need to pre-process here.
if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
rc_ = cs.BatchFetchRows(&rq_, &reply_);
if (!rc_.IsInterrupted()) {
Status2CacheReply(rc_, &reply_);
st_ = CacheServerRequest::STATE::FINISH;
responder_.Finish(reply_, grpc::Status::OK, this);
} else {
// The exception is BatchFetch and BatchCache which we need to pre-process here.
// Also some requests are urgent that we want to process them here too.
if (type_ == BaseRequest::RequestType::kBatchFetchRows || type_ == BaseRequest::RequestType::kBatchCacheRows ||
type_ == BaseRequest::RequestType::kStopService || type_ == BaseRequest::RequestType::kAllocateSharedBlock ||
type_ == BaseRequest::RequestType::kFreeSharedBlock) {
cs.ProcessRequest(this);
// For cache_admin --stop, ProcessRequest is just acknowledging we receive the request. Now
// we call the real function.
if (type_ == BaseRequest::RequestType::kStopService) {
cs.GlobalShutdown();
return Status(StatusCode::kInterrupted);
} else if (rc_.IsInterrupted()) {
return rc_;
}
} else {
@ -191,10 +186,12 @@ Status CacheServerGreeterImpl::MonitorUnixSocket() {
// If the unix socket is recreated for whatever reason, this server instance will be stale and
// no other process and communicate with us. In this case we need to shutdown ourselves.
if (p.Exists()) {
auto &cs = CacheServer::GetInstance();
SharedMemory::shm_key_t key;
RETURN_IF_NOT_OK(PortToFtok(port_, &key));
if (key != shm_key_) {
std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key_) +
auto shm_key = cs.GetKey();
if (key != shm_key) {
std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key) +
". New key " + std::to_string(key) + ". Shutting down server";
MS_LOG(ERROR) << errMsg;
RETURN_STATUS_UNEXPECTED(errMsg);

@ -18,12 +18,14 @@
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
@ -75,7 +77,7 @@ class CacheServerGreeterImpl final {
friend class CacheServer;
public:
explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
explicit CacheServerGreeterImpl(int32_t port);
virtual ~CacheServerGreeterImpl();
/// \brief Brings up gRPC server
/// \return none
@ -83,24 +85,18 @@ class CacheServerGreeterImpl final {
/// \brief Entry function to handle cache server request
Status HandleRequest(int32_t worker_id);
/// Return the shared memory pool.
/// \return Return the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
/// \brief Montor the status of the unix socket in case it is gone.
Status MonitorUnixSocket();
/// \brief This shutdown down the comm layer
void Shutdown();
private:
int32_t port_;
size_t shm_pool_sz_in_gb_;
std::string unix_socket_;
CacheServerGreeter::AsyncService svc_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
SharedMemory::shm_key_t shm_key_;
};
} // namespace dataset
} // namespace mindspore

@ -209,6 +209,14 @@ void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) {
#endif
}
void CacheServerHW::AssignToNode(numa_id_t numa_id, void *ptr, size_t sz) {
#ifdef NUMA_ENABLED
if (numa_enabled()) {
numa_tonode_memory(ptr, sz, numa_id);
}
#endif
}
bool CacheServerHW::numa_enabled() {
#ifdef NUMA_ENABLED
return (numa_available() != -1);

@ -63,6 +63,9 @@ class CacheServerHW {
/// \brief Interleave a given memory block. Used by shared memory only.
static void InterleaveMemory(void *ptr, size_t sz);
/// \brief Assign a given memory block to a numa node. Used by shared memory only.
void AssignToNode(numa_id_t numa_id, void *ptr, size_t sz);
/// \brief Set default memory policy.
static Status SetDefaultMemoryPolicy(CachePoolPolicy);

@ -53,7 +53,7 @@ Status SharedMessage::Create() {
Status SharedMessage::SendStatus(const Status &rc) {
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
StatusMsgBuf msg{
CacheMsgBuf msg{
1,
};
msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
@ -71,7 +71,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
Status SharedMessage::ReceiveStatus(Status *rc) {
RETURN_UNEXPECTED_IF_NULL(rc);
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
struct StatusMsgBuf msg {};
struct CacheMsgBuf msg {};
auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
if (err == -1) {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);

@ -28,7 +28,7 @@
namespace mindspore {
namespace dataset {
/// A message queue structure between the parent and the child process
struct StatusMsgBuf {
struct CacheMsgBuf {
int64_t mtype;
union {
char mtext[1];

@ -168,12 +168,14 @@ Path CachePool::GetSpillPath() const {
}
CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
tree_->LockShared(); // Prevent any node split while we search.
CacheStat cs{-1, -1, 0, 0, 0, 0};
int64_t total_sz = 0;
if (tree_->begin() != tree_->end()) {
cs.min_key = tree_->begin().key();
cs.max_key = cs.min_key; // will adjust later.
for (auto it = tree_->begin(); it != tree_->end(); ++it) {
it.LockShared();
total_sz += it.value().sz;
if (it.value().ptr != nullptr) {
++cs.num_mem_cached;
@ -190,6 +192,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
}
}
cs.max_key = cur_key;
it.Unlock();
}
}
if (total_sz > 0) {
@ -199,6 +202,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
cs.average_cache_sz = 1;
}
}
tree_->Unlock();
return cs;
}

@ -58,7 +58,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te
if (sent_using_local_bypass) {
MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
// Allocate shared memory from the server
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
RETURN_IF_NOT_OK(mem_rq->Wait());
addr_ = mem_rq->GetAddr();
@ -305,6 +305,15 @@ Status GetStatRequest::PostReply() {
return Status::OK();
}
Status GetCacheStateRequest::PostReply() {
try {
cache_service_state_ = std::stoi(reply_.result());
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
Status ListSessionsRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
auto session_vector = msg->sessions();
@ -333,5 +342,13 @@ Status ServerStopRequest::PostReply() {
return Status::OK();
}
BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
: BaseRequest(RequestType::kBatchCacheRows) {
rq_.set_connection_id(cc->server_connection_id_);
rq_.set_client_id(cc->client_id_);
rq_.add_buf_data(cc->cookie());
rq_.add_buf_data(std::to_string(addr));
rq_.add_buf_data(std::to_string(num_ele));
}
} // namespace dataset
} // namespace mindspore

@ -78,6 +78,9 @@ class BaseRequest {
kListSessions = 16,
kConnectReset = 17,
kInternalFetchRow = 18,
kBatchCacheRows = 19,
kInternalCacheRow = 20,
kGetCacheState = 21,
// Add new request before it.
kRequestUnknown = 32767
};
@ -133,10 +136,11 @@ class BaseRequest {
class FreeSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
explicit FreeSharedBlockRequest(connection_id_type connection_id, int32_t client_id, int64_t addr)
: BaseRequest(RequestType::kFreeSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(addr));
rq_.set_client_id(client_id);
}
~FreeSharedBlockRequest() override = default;
};
@ -178,7 +182,7 @@ class CacheRowRequest : public BaseRequest {
/// the shared memory by sending another request. The following function will generate a suitable
/// request for the CacheClient to send.
std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_);
return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), rq_.client_id(), addr_);
}
private:
@ -271,6 +275,24 @@ class GetStatRequest : public BaseRequest {
CacheServiceStat stat_{};
};
/// \brief Get the state of a cache service
class GetCacheStateRequest : public BaseRequest {
public:
friend class CacheServer;
explicit GetCacheStateRequest(connection_id_type connection_id)
: BaseRequest(RequestType::kGetCacheState), cache_service_state_(0) {
rq_.set_connection_id(connection_id);
}
~GetCacheStateRequest() override = default;
Status PostReply() override;
auto GetState() const { return cache_service_state_; }
private:
int8_t cache_service_state_;
};
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
@ -367,10 +389,11 @@ class ListSessionsRequest : public BaseRequest {
class AllocateSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
explicit AllocateSharedBlockRequest(connection_id_type connection_id, int32_t client_id, size_t requestedSz)
: BaseRequest(RequestType::kAllocateSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(requestedSz));
rq_.set_client_id(client_id);
}
~AllocateSharedBlockRequest() override = default;
@ -420,6 +443,13 @@ class ConnectResetRequest : public BaseRequest {
return Status::OK();
}
};
class BatchCacheRowsRequest : public BaseRequest {
public:
friend class CacheServer;
explicit BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele);
~BatchCacheRowsRequest() override = default;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_

File diff suppressed because it is too large Load Diff

@ -25,12 +25,14 @@
#include <chrono>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include <set>
#include <thread>
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/engine/cache/cache_service.h"
@ -196,15 +198,31 @@ class CacheServer : public Service {
/// \brief Check if we bind threads to numa cores
bool IsNumaAffinityOn() const { return numa_affinity_; }
/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
/// \brief Return the memory cap ratio
float GetMemoryCapRatio() const { return memory_cap_ratio_; }
/// \brief How a request is handled.
/// \note that it can be process immediately by a grpc thread or routed to a server thread
/// which is pinned to some numa node core.
void ProcessRequest(CacheServerRequest *cache_req);
void GlobalShutdown();
/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
/// we can't return the absolute address as this makes no sense
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shm_->SharedMemoryBaseAddr(); }
/// \brief Return the public key of the shared memory.
int32_t GetKey() const { return shm_->GetKey(); }
Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);
void DeallocateSharedMemory(int32_t client_id, void *p);
private:
static std::once_flag init_instance_flag_;
static CacheServer *instance_;
@ -228,6 +246,7 @@ class CacheServer : public Service {
std::map<worker_id_t, Task *> numa_tasks_;
bool numa_affinity_;
std::vector<int32_t> shutdown_qIDs_;
std::unique_ptr<CachedSharedMemory> shm_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
@ -315,7 +334,7 @@ class CacheServer : public Service {
/// \brief A proper shutdown of the server
/// \return Status object
Status GlobalShutdown(CacheServerRequest *);
Status AcknowledgeShutdown(CacheServerRequest *cache_req);
/// \brief Find keys that will be cache miss
/// \return Status object
@ -332,12 +351,19 @@ class CacheServer : public Service {
/// \brief Connect request by a pipeline
Status ConnectReset(CacheRequest *rq);
/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
Status BatchCacheRows(CacheRequest *rq);
};
} // namespace dataset
} // namespace mindspore

@ -68,10 +68,11 @@ Status CacheService::DoServiceStop() {
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == CacheServiceState::kFetchPhase) {
if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
@ -119,6 +120,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
st_ = CacheServiceState::kOutOfMemory;
}
}
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
@ -130,10 +141,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == CacheServiceState::kFetchPhase) {
if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
@ -161,6 +173,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
st_ = CacheServiceState::kOutOfMemory;
}
}
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
@ -202,16 +224,17 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(out);
out->stat_ = cp_->GetStat();
out->state_ = static_cast<ServiceStat::state_type>(st_);
out->state_ = static_cast<ServiceStat::state_type>(st_.load());
return Status::OK();
}
Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
datalocator_v.reserve(v.size());
@ -271,7 +294,8 @@ Status CacheService::FetchSchema(std::string *out) const {
SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
RETURN_UNEXPECTED_IF_NULL(out);
// We are going to use std::string to allocate and hold the result which will be eventually
@ -292,6 +316,7 @@ Status CacheService::BuildPhaseDone() {
UniqueLock rw(&rw_lock_);
st_ = CacheServiceState::kFetchPhase;
cp_->SetLocking(false);
MS_LOG(WARNING) << "Locking mode is switched off.";
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");

@ -91,6 +91,8 @@ class CacheService : public Service {
/// \param[in/out] A pointer to a pre-allocated ServiceStat structure
/// \return Status Object
Status GetStat(ServiceStat *);
/// \brief Return the current state
CacheServiceState GetState() const { return st_.load(); }
/// \brief Cache schema
/// \param buf A Google Flatbuffer that contains the schema
/// \param len size of the buffer
@ -131,7 +133,7 @@ class CacheService : public Service {
bool generate_id_;
std::string cookie_;
std::atomic<int32_t> num_clients_;
CacheServiceState st_;
std::atomic<CacheServiceState> st_;
std::string schema_;
std::shared_ptr<NumaMemoryPool> numa_pool_;
// We also cache the result from calling FindKeysMiss because it is expensive. Besides user make

@ -427,6 +427,7 @@ Status CachePerfRun::Run() {
}
// Now we create the children knowing all two sets of message queues are constructed.
auto start_tick = std::chrono::steady_clock::now();
for (auto i = 0; i < num_pipelines_; ++i) {
auto pid = fork();
if (pid == 0) {
@ -502,6 +503,10 @@ Status CachePerfRun::Run() {
// Wait until all pipelines finish the first epoch.
RETURN_IF_NOT_OK(pipeline_wp_.Wait());
auto end_tick = std::chrono::steady_clock::now();
int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count();
std::cout << "Epoch one (build phase) elapsed time " << elapse_time << " seconds" << std::endl;
std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer()
<< std::endl;
@ -543,6 +548,7 @@ Status CachePerfRun::Run() {
epoch_sync_cnt_ = 0;
pipeline_wp_.Clear();
epoch_results_.clear();
start_tick = std::chrono::steady_clock::now();
// Signal each pipeline to start
for (auto msg_qid : msg_send_lists_) {
CachePerfMsg msg;
@ -551,6 +557,9 @@ Status CachePerfRun::Run() {
}
// Wait for the child to finish
RETURN_IF_NOT_OK(pipeline_wp_.Wait());
end_tick = std::chrono::steady_clock::now();
elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count();
std::cout << "Epoch " << epoch_num << " elapsed time " << elapse_time << " seconds" << std::endl;
std::cout << "Epoch " << epoch_num
<< " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl;
PrintEpochSummary();

@ -238,6 +238,9 @@ Status CachePipelineRun::RunFirstEpoch() {
RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking));
}
// Final flush
cc_->FlushAsyncWriteBuffer();
// Send a message saying epoch one done for this pipeline.
EpochDone proto;
proto.set_pipeline(my_pipeline_);
@ -291,7 +294,7 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) {
buffer->set_tensor_table(std::move(tensor_table));
// Measure the time to call WriteBuffer
auto start_tick = std::chrono::steady_clock::now();
rc = cc_->WriteBuffer(std::move(buffer));
rc = cc_->AsyncWriteBuffer(std::move(buffer));
auto end_tick = std::chrono::steady_clock::now();
if (rc.IsError()) {
if (rc.IsOutofMemory() || rc.IsNoSpace()) {

@ -122,6 +122,17 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
if (db_ptr->eoe()) {
// Ignore it.
MS_LOG(DEBUG) << "Ignore eoe";
// However we need to flush any left over from the async write buffer. But any error
// we are getting will just to stop caching but the pipeline will continue
Status rc;
if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) {
cache_missing_rows_ = false;
if (rc.IsOutofMemory() || rc.IsNoSpace()) {
cache_client_->ServerRunningOutOfResources();
} else {
MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString();
}
}
} else {
while (db_ptr->NumRows() > 0) {
TensorRow row;
@ -143,6 +154,9 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
cache_missing_rows_ = false;
cache_client_->ServerRunningOutOfResources();
}
}
}
@ -309,17 +323,25 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
if (st_.compare_exchange_strong(expected, State::kDirty)) {
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {
// Send the request async. The cleaner will check the return code.
rc = cc->PushRequest(cleaner_copy_);
rc = cc->AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) {
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {
// Send the request async. The cleaner will check the return code.
rc = cc->PushRequest(cleaner_copy_);
}
} else if (rc.IsOk()) {
// Set the state to clean even though it still sits in the cache client async buffer.
// The cleaner will then ignore it once the state is clean.
st_ = State::kClean;
}
if (rc.IsError()) {
// Clean up the shared pointer and reset the state back to empty
cleaner_copy_.reset();
st_ = State::kEmpty;
}
return rc;
}
return Status::OK();
}

@ -109,7 +109,14 @@ Status CacheOp::CacheAllRows(int32_t worker_id) {
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
while (!db_ptr->eof()) {
if (!db_ptr->eoe()) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
Status rc;
// Do the Async write if we attach to the shared memory.
rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr));
if (rc.get_code() == StatusCode::kNotImplementedYet) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
} else if (rc.IsError()) {
return rc;
}
} else {
// In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
// as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
@ -139,21 +146,41 @@ Status CacheOp::WaitForCachingAllRows() {
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
// Move from build phase to fetch phase if we are the one to fill the cache
if (phase_ == Phase::kBuildPhase) {
RETURN_IF_NOT_OK(cache_client_->FlushAsyncWriteBuffer()); // One more flush
RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
// Move to the next phase
phase_ = Phase::kFetchPhase;
}
// Get statistics from the server, and if we are not the one to create the cache,
// If we are not the one to create the cache,
// wait until the state changed from build phase to fetch base.
CacheServiceStat stat{};
bool BuildPhaseDone = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase);
if (!BuildPhaseDone) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int8_t out;
RETURN_IF_NOT_OK(cache_client_->GetState(&out));
auto state = static_cast<CacheServiceState>(out);
switch (state) {
case CacheServiceState::kBuildPhase:
// Do nothing. Continue to wait.
BuildPhaseDone = false;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
break;
case CacheServiceState::kFetchPhase:
BuildPhaseDone = true;
break;
case CacheServiceState::kOutOfMemory:
return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory");
case CacheServiceState::kNoSpace:
return Status(StatusCode::kNoSpace, "Cache server is running of out spill storage");
case CacheServiceState::kNone:
case CacheServiceState::kError:
default:
RETURN_STATUS_UNEXPECTED("Unexpected state: " + std::to_string(out));
}
} while (!BuildPhaseDone);
// Get statistics from the server, and if we are not the one to create the cache,
// wait until the state changed from build phase to fetch base.
CacheServiceStat stat{};
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
const row_id_type min_key = stat.min_row_id;
const row_id_type max_key = stat.max_row_id;
num_rows_ = max_key - min_key + 1;

@ -148,6 +148,12 @@ class BPlusTree {
acquire_lock_ = on_off;
}
void LockShared() { rw_lock_.LockShared(); }
void LockExclusive() { rw_lock_.LockExclusive(); }
void Unlock() { rw_lock_.Unlock(); }
private:
// Abstract class of a node (leaf or inner)
class BaseNode {
@ -409,6 +415,21 @@ class BPlusTree {
bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }
void LockShared() {
cur_->rw_lock_.LockShared();
locked_ = true;
}
void LockExclusive() {
cur_->rw_lock_.LockExclusive();
locked_ = true;
}
void Unlock() {
cur_->rw_lock_.Unlock();
locked_ = false;
}
private:
typename BPlusTree::LeafNode *cur_;
slot_type slot_;
@ -458,6 +479,21 @@ class BPlusTree {
bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }
void LockShared() {
cur_->rw_lock_.LockShared();
locked_ = true;
}
void LockExclusive() {
cur_->rw_lock_.LockExclusive();
locked_ = true;
}
void Unlock() {
cur_->rw_lock_.Unlock();
locked_ = false;
}
private:
const typename BPlusTree::LeafNode *cur_;
slot_type slot_;

Loading…
Cancel
Save