diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc
index e7096db469..5542f61189 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc
@@ -14,34 +14,96 @@
  * limitations under the License.
  */
 #include "minddata/dataset/engine/cache/cache_arena.h"
+#include "minddata/dataset/engine/cache/cache_server.h"
 #include "minddata/dataset/util/path.h"
 namespace mindspore {
 namespace dataset {
-CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
+CachedSharedMemory::CachedSharedMemory(int32_t port, size_t val_in_GB)
+    : shared_memory_sz_in_gb_(val_in_GB), port_(port), num_numa_nodes_(-1), sub_pool_sz_(-1) {
   // We create the shared memory and we will destroy it. All other clients just detach.
   shm_.RemoveResourcesOnExit();
 }
-CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
+CachedSharedMemory::~CachedSharedMemory() = default;
 
-Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
-                                            size_t val_in_GB) {
-  RETURN_UNEXPECTED_IF_NULL(out);
-  auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
-  if (ba == nullptr) {
-    return Status(StatusCode::kOutOfMemory);
-  }
-  // Transfer the ownership of this pointer. Any future error in the processing we will have
-  // the destructor of *out to deal.
-  (*out).reset(ba);
+Status CachedSharedMemory::Init() {
+  CacheServer &cs = CacheServer::GetInstance();
+  num_numa_nodes_ = cs.GetNumaNodeCount();
   // Generate the ftok using a combination of port.
   SharedMemory::shm_key_t shm_key;
-  RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
-  ba->shm_.SetPublicKey(shm_key);
+  RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
+  shm_.SetPublicKey(shm_key);
   // Value is in GB. Convert into bytes.
-  int64_t sz = val_in_GB * 1073741824L;
-  RETURN_IF_NOT_OK(ba->shm_.Create(sz));
-  ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
+  int64_t shm_mem_sz = shared_memory_sz_in_gb_ * 1073741824L;
+  RETURN_IF_NOT_OK(shm_.Create(shm_mem_sz));
+  MS_LOG(INFO) << "Creation of shared memory successful. Shared memory key " << shm_.GetKey();
+  // Interleave the memory.
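+  // On a NUMA-enabled build, InterleaveMemory is expected to reduce to libnuma's
+  // page-interleaving call, roughly (a sketch for illustration, not part of this patch):
+  //
+  //   numa_interleave_memory(shm_.SharedMemoryBaseAddr(), shm_mem_sz, numa_all_nodes_ptr);
+  //
+  // so pages of this hot region are spread across sockets instead of all landing on
+  // whichever node touches them first.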
+  cs.GetHWControl()->InterleaveMemory(shm_.SharedMemoryBaseAddr(), shm_mem_sz);
+  // We will create a number of sub-pools out of the shared memory to reduce latch contention.
+  int32_t num_of_pools = num_numa_nodes_;
+  if (num_numa_nodes_ == 1) {
+    num_of_pools = shared_memory_sz_in_gb_ * 2;
+  }
+  sub_pool_sz_ = shm_mem_sz / num_of_pools;
+  // If each sub-pool is too small, readjust the number of pools.
+  constexpr int64_t min_subpool_sz = 512 * 1048576L;
+  if (sub_pool_sz_ < min_subpool_sz) {
+    sub_pool_sz_ = min_subpool_sz;
+    num_of_pools = shm_mem_sz / min_subpool_sz;
+  }
+  shm_pool_.reserve(num_of_pools);
+  for (auto i = 0; i < num_of_pools; ++i) {
+    void *ptr = static_cast<char *>(shm_.SharedMemoryBaseAddr()) + i * sub_pool_sz_;
+    shm_pool_.push_back(std::make_unique<ArenaImpl>(ptr, sub_pool_sz_));
+  }
+  mux_ = std::make_unique<std::mutex[]>(num_of_pools);
+  return Status::OK();
+}
+
+Status CachedSharedMemory::CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB) {
+  RETURN_UNEXPECTED_IF_NULL(out);
+  auto mem_pool = std::unique_ptr<CachedSharedMemory>(new CachedSharedMemory(port, val_in_GB));
+  RETURN_IF_NOT_OK(mem_pool->Init());
+  *out = std::move(mem_pool);
+  return Status::OK();
+}
+
+Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
+  Status rc;
+  RETURN_UNEXPECTED_IF_NULL(p);
+  auto begin_slot = client_id % shm_pool_.size();
+  auto slot = begin_slot;
+  do {
+    std::unique_lock<std::mutex> lock(mux_[slot]);
+    rc = shm_pool_[slot]->Allocate(sz, p);
+    if (rc.IsOutofMemory()) {
+      slot = (slot + 1) % shm_pool_.size();
+    }
+  } while (rc.IsError() && slot != begin_slot);
+  if (rc.IsError()) {
+    return rc;
+  }
   return Status::OK();
 }
+
+void CachedSharedMemory::DeallocateSharedMemory(int32_t client_id, void *p) {
+  auto begin_slot = client_id % shm_pool_.size();
+  auto slot = begin_slot;
+  auto start_addr = static_cast<char *>(SharedMemoryBaseAddr());
+  bool found = false;
+  do {
+    auto ptr = start_addr + slot * sub_pool_sz_;
+    if (ptr <= p && p < (ptr + sub_pool_sz_)) {
+      std::unique_lock<std::mutex> lock(mux_[slot]);
+      shm_pool_[slot]->Deallocate(p);
+      found = true;
+      break;
+    } else {
+      slot = (slot + 1) % shm_pool_.size();
+    }
+  } while (slot != begin_slot);
+  if (!found) {
+    MS_LOG(ERROR) << "Programming error. Can't find the arena that the pointer " << p << " came from";
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h
index 6f61960e69..743fe475f1 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h
@@ -18,25 +18,29 @@
 #include <memory>
 #include <mutex>
+#include <string>
 #include <utility>
+#include <vector>
 #include "minddata/dataset/util/arena.h"
 #include "minddata/dataset/engine/cache/cache_common.h"
 #include "minddata/dataset/engine/cache/cache_ipc.h"
 namespace mindspore {
 namespace dataset {
-/// This is a derived class of Arena but resides in shared memory
-class CachedSharedMemoryArena : public MemoryPool {
+/// This is like a CircularPool but each arena is in shared memory and
+/// possibly bound to a numa socket.
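+/// (Worked numbers for illustration: with this patch's default 4 GB arena on a
+/// single-numa machine, CachedSharedMemory::Init() creates 4 * 2 = 8 sub-pools of
+/// 512 MB each, exactly the minimum sub-pool size, and a pipeline with client_id 5
+/// starts allocating from sub-pool 5 % 8 = 5, spilling to sub-pool 6 only on
+/// out-of-memory. A two-node box would instead get two 2 GB pools.)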
+class CachedSharedMemory {
  public:
   // Disable copy and assignment constructor
-  CachedSharedMemoryArena(const CachedSharedMemoryArena &) = delete;
-  CachedSharedMemoryArena &operator=(const CachedSharedMemoryArena &) = delete;
-  ~CachedSharedMemoryArena() override;
+  CachedSharedMemory(const CachedSharedMemory &) = delete;
+  CachedSharedMemory &operator=(const CachedSharedMemory &) = delete;
+  ~CachedSharedMemory();
+
   /// \brief Create an Arena in shared memory
   /// \param[out] out Pointer to a unique_ptr
   /// \param port Port used to derive the shared memory key
   /// \param val_in_GB size of shared memory in gigabytes
   /// \return Status object
-  static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
+  static Status CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB);
 
   /// \brief This returns where we attach to the shared memory.
   /// Some gRPC requests will ask for a shared memory block, and
@@ -44,45 +48,29 @@ class CachedSharedMemoryArena : public MemoryPool {
   /// in the client. So instead we will return an address relative
   /// to the base address of the shared memory where we attach to.
   /// \return Base address of the shared memory.
-  const void *SharedMemoryBaseAddr() const { return impl_->get_base_addr(); }
-
-  /// As a derived class of MemoryPool, we have to implement the following
-  /// But we simply transfer the call to the implementation class
-  Status Allocate(size_t size, void **pVoid) override {
-    std::unique_lock<std::mutex> lock(mux_);
-    return impl_->Allocate(size, pVoid);
-  }
-  Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override {
-    std::unique_lock<std::mutex> lock(mux_);
-    return impl_->Reallocate(pVoid, old_sz, new_sz);
-  }
-  void Deallocate(void *pVoid) override {
-    std::unique_lock<std::mutex> lock(mux_);
-    impl_->Deallocate(pVoid);
-  }
-  uint64_t get_max_size() const override { return impl_->get_max_size(); }
-  int PercentFree() const override {
-    std::unique_lock<std::mutex> lock(mux_);
-    return impl_->PercentFree();
-  }
-
-  /// \brief Dump the memory allocation block.
-  friend std::ostream &operator<<(std::ostream &os, const CachedSharedMemoryArena &s) {
-    os << *(s.impl_);
-    return os;
-  }
+  const void *SharedMemoryBaseAddr() const { return shm_.SharedMemoryBaseAddr(); }
+  void *SharedMemoryBaseAddr() { return shm_.SharedMemoryBaseAddr(); }
 
   /// \brief Get the shared memory key of the shared memory
   SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }
 
+  /// \brief Allocate shared memory for a given pipeline
+  Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);
+
+  /// \brief Deallocate shared memory for a given pipeline
+  void DeallocateSharedMemory(int32_t client_id, void *p);
+
  private:
-  mutable std::mutex mux_;
-  int32_t val_in_GB_;
+  int32_t shared_memory_sz_in_gb_;
   int32_t port_;
   SharedMemory shm_;
-  std::unique_ptr<ArenaImpl> impl_;
+  std::vector<std::unique_ptr<ArenaImpl>> shm_pool_;
+  std::unique_ptr<std::mutex[]> mux_;
+  int32_t num_numa_nodes_;
+  int64_t sub_pool_sz_;
+
   /// Private constructor. Not to be called directly.
-  CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
+  CachedSharedMemory(int32_t port, size_t val_in_GB);
+  Status Init();
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
index 6c970e4b83..9b08091fc1 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
@@ -19,6 +19,7 @@
 #include "minddata/dataset/engine/cache/cache_request.h"
 #include "minddata/dataset/engine/cache/cache_fbb.h"
 #include "minddata/dataset/util/bit.h"
+#include "minddata/dataset/util/task_manager.h"
 namespace mindspore {
 namespace dataset {
@@ -71,6 +72,10 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
 CacheClient::~CacheClient() {
   cache_miss_keys_wp_.Set();
+  // Manually release the async buffer because we need the comm layer.
+  if (async_buffer_stream_) {
+    async_buffer_stream_->ReleaseBuffer();
+  }
   if (client_id_ != -1) {
     try {
       // Send a message to the server, saying I am done.
@@ -132,6 +137,42 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
   return Status::OK();
 }
 
+Status CacheClient::AsyncWriteRow(const TensorRow &row) {
+  if (async_buffer_stream_ == nullptr) {
+    return Status(StatusCode::kNotImplementedYet);
+  }
+  RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
+  return Status::OK();
+}
+
+Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
+  if (async_buffer_stream_ == nullptr) {
+    return Status(StatusCode::kNotImplementedYet);
+  } else {
+    Status rc;
+    std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
+    auto num_rows = in->NumRows();
+    if (num_rows > 0) {
+      for (auto i = 0; i < num_rows; ++i) {
+        TensorRow row;
+        RETURN_IF_NOT_OK(in->PopRow(&row));
+        rc = AsyncWriteRow(row);
+        if (rc.get_code() == StatusCode::kNotImplementedYet) {
+          tensor_table->push_back(row);
+        } else if (rc.IsError()) {
+          return rc;
+        }
+      }
+    }
+    // If not all of them can be sent async, return what's left back to the caller.
+    if (!tensor_table->empty()) {
+      in->set_tensor_table(std::move(tensor_table));
+      return Status(StatusCode::kNotImplementedYet);
+    }
+  }
+  return Status::OK();
+}
+
 Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
   RETURN_UNEXPECTED_IF_NULL(out);
   auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
@@ -141,7 +182,7 @@ Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable
   Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
   // Free the memory by sending a request back to the server.
   if (mem_addr != -1) {
-    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
+    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, client_id_, mem_addr);
     Status rc2 = PushRequest(mfree_req);
     // But we won't wait for the result for the sake of performance.
     if (rc.IsOk() && rc2.IsError()) {
@@ -211,6 +252,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
   if (success) {
     // Attach to shared memory for local client
     RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
+    if (local_bypass_) {
+      async_buffer_stream_ = std::make_shared<AsyncBufferStream>();
+      RETURN_IF_NOT_OK(async_buffer_stream_->Init(this));
+    }
   }
   // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
   // CacheOp to bypass the build phase.
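The kNotImplementedYet convention above doubles as a "send it yourself" signal: rows too large for an async window, or a client without local bypass, are handed back to the caller for the synchronous path. A caller-side sketch (hypothetical pipeline code, not part of this patch; cc is a connected CacheClient and db a DataBuffer):

    // AsyncWriteBuffer takes an rvalue reference but only consumes the rows it accepts,
    // so db still holds the leftovers when kNotImplementedYet comes back.
    Status rc = cc->AsyncWriteBuffer(std::move(db));
    if (rc.get_code() == StatusCode::kNotImplementedYet) {
      rc = cc->WriteBuffer(std::move(db));  // fall back to the blocking rpc path
    }
    RETURN_IF_NOT_OK(rc);
    // ... and on eoe, push out whatever is still sitting in the async windows:
    RETURN_IF_NOT_OK(cc->FlushAsyncWriteBuffer());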
@@ -240,6 +285,17 @@ Status CacheClient::GetStat(CacheServiceStat *stat) {
   return Status::OK();
 }
 
+Status CacheClient::GetState(int8_t *out) {
+  SharedLock lck(&mux_);
+  RETURN_UNEXPECTED_IF_NULL(out);
+  CHECK_FAIL_RETURN_UNEXPECTED(server_connection_id_ != 0, "GetState called but the cache is not in use yet.");
+  auto rq = std::make_shared<GetCacheStateRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  *out = rq->GetState();
+  return Status::OK();
+}
+
 Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
   SharedLock lck(&mux_);
   auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
@@ -334,5 +390,181 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
     return it != gap_.end();
   }
 }
+
+CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
+
+CacheClient::AsyncBufferStream::~AsyncBufferStream() {
+  (void)vg_.ServiceStop();
+  writer_wp_.Set();
+  (void)ReleaseBuffer();
+}
+
+Status CacheClient::AsyncBufferStream::ReleaseBuffer() {
+  if (offset_addr_ != -1) {
+    auto mfree_req =
+      std::make_shared<FreeSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(), offset_addr_);
+    offset_addr_ = -1;
+    RETURN_IF_NOT_OK(cc_->PushRequest(mfree_req));
+    RETURN_IF_NOT_OK(mfree_req->Wait());
+  }
+  return Status::OK();
+}
+
+Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
+  cc_ = cc;
+  // Allocate shared memory from the server
+  auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(),
+                                                             kAsyncBufferSize * kNumAsyncBuffer);
+  RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
+  RETURN_IF_NOT_OK(mem_rq->Wait());
+  offset_addr_ = mem_rq->GetAddr();
+  // Now we need to add that to the base address of where we attach.
+  auto base = cc->SharedMemoryBaseAddr();
+  auto start = reinterpret_cast<int64_t>(base) + offset_addr_;
+  for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+    // We only need to set the pointer during init. Other fields will be set dynamically.
+    buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
+  }
+  buf_arr_[0].begin_addr_ = 0;
+  buf_arr_[0].end_addr_ = 0;
+  buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
+  buf_arr_[0].num_ele_ = 0;
+  RETURN_IF_NOT_OK(vg_.ServiceStart());
+  RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this)));
+  return Status::OK();
+}
+
+Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
+  std::vector<ReadableSlice> v;
+  v.reserve(row.size() + 1);
+  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
+  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
+  int64_t sz = fbb->GetSize();
+  v.emplace_back(fbb->GetBufferPointer(), sz);
+  for (const auto &ts : row) {
+    sz += ts->SizeInBytes();
+    v.emplace_back(ts->GetBuffer(), ts->SizeInBytes());
+  }
+  // If the size is too big, tell the user to send it directly.
+  if (sz > kAsyncBufferSize) {
+    return Status(StatusCode::kNotImplementedYet);
+  }
+  // Find out where we are going to write in the (logical) buffer stream without acquiring the lock
+  // but only use the atomic variable.
+  auto write_addr = next_addr_.fetch_add(sz);
+  Status rc;
+  do {
+    SharedLock lock(&mux_);
+    // Check error from the server side while we have the lock.
+    RETURN_IF_NOT_OK(flush_rc_);
+    AsyncWriter *asyncWriter = &buf_arr_[cur_];
+    rc = asyncWriter->Write(write_addr, sz, v);
+    if (rc.get_code() == StatusCode::kNoSpace) {
+      // If no space, wake up the async flush thread
+      writer_wp_.Clear();
+      flush_wp_.Set();
+      // Let go of the lock before we wait.
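+      // Illustrative numbers (not from the patch): with a 16M window, a writer that
+      // reserved write_addr 0 for a 6M row copies immediately; a second writer that
+      // reserved write_addr 6M for a 12M row computes rel_end_addr 18M > 16M, gets
+      // kNoSpace, and parks here until SyncFlush ships the window and re-bases the
+      // next one at begin_addr_ = 6M, after which the retry lands at offset 0.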
+      lock.Unlock();
+      // Wait for the next window
+      RETURN_IF_NOT_OK(writer_wp_.Wait());
+    }
+  } while (rc.get_code() == StatusCode::kNoSpace);
+  return rc;
+}
+
+Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) {
+  bool retry = false;
+  do {
+    UniqueLock lock(&mux_);
+    flush_wp_.Clear();
+    auto *asyncWriter = &buf_arr_[cur_];
+    retry = false;
+    // Because the clients are copying async, we need to wait until all of them have written.
+    if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) {
+      if (asyncWriter->num_ele_) {
+        asyncWriter->rq.reset(
+          new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
+        flush_rc_ = cc_->PushRequest(asyncWriter->rq);
+        if (flush_rc_.IsOk()) {
+          // If we are asked to wait, say this is the final flush, just wait for its completion.
+          if (blocking) {
+            flush_rc_ = asyncWriter->rq->Wait();
+            asyncWriter->rq.reset();
+          }
+          // Prepare for the next buffer which will start from the end addr of the previous buffer.
+          int64_t previous_end_addr = asyncWriter->end_addr_;
+          // Update the cur_ while we have the lock.
+          cur_ = (cur_ + 1) % kNumAsyncBuffer;
+          asyncWriter = &buf_arr_[cur_];
+          // Before we do anything, make sure the cache server is done with this buffer, or we will
+          // corrupt its content. We can also pick up any error from a previous flush here.
+          if (asyncWriter->rq) {
+            // Save the result into a common area, so workers can see it and quit.
+            flush_rc_ = asyncWriter->rq->Wait();
+            asyncWriter->rq.reset();
+          }
+          asyncWriter->bytes_avail_ = kAsyncBufferSize;
+          asyncWriter->num_ele_ = 0;
+          asyncWriter->begin_addr_ = previous_end_addr;
+          asyncWriter->end_addr_ = previous_end_addr;
+        }
+      }
+    } else {
+      // Some clients are late and aren't done yet. Let go of the lock.
+      lock.Unlock();
+      retry = true;
+      writer_wp_.Set();
+      std::this_thread::yield();
+    }
+  } while (retry);
+  // Wake up any writer that is waiting.
+  writer_wp_.Set();
+  return flush_rc_;
+}
+
+Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz,
+                                                          const std::vector<ReadableSlice> &v) {
+  // Map our logical address to the real physical address in the buffer, i.e. where we start and
+  // where we end.
+  auto rel_write_addr = write_addr - begin_addr_;
+  auto rel_end_addr = rel_write_addr + sz;
+  // If not enough space, time to flush and swap.
+  if (rel_end_addr > kAsyncBufferSize) {
+    return Status(StatusCode::kNoSpace);
+  }
+  for (auto &p : v) {
+    auto write_sz = p.GetSize();
+    WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz);
+    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
+    bytes_avail_ -= write_sz;
+    rel_write_addr += write_sz;
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error");
+  ++num_ele_;
+  // Update the end_addr if ours is better
+  int64_t new_end_addr = write_addr + sz;
+  int64_t expected = end_addr_;
+  while (expected < new_end_addr) {
+    if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) {
+      expected = end_addr_;
+    }
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error");
+  return Status::OK();
+}
+
+Status CacheClient::AsyncBufferStream::AsyncFlush() {
+  TaskManager::FindMe()->Post();
+  Status rc;
+  do {
+    RETURN_IF_NOT_OK(flush_wp_.Wait());
+    RETURN_IF_INTERRUPTED();
+    rc = SyncFlush();
+    // Other than resource errors, any other error makes us quit.
+  } while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace());
+  // Make sure we wake up workers waiting for us.
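+  // (A writer can still be parked in writer_wp_.Wait() after a kNoSpace; if this
+  // thread exited on error without the final Set(), that writer would hang forever.)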
+  writer_wp_.Set();
+  return rc;
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
index 7f3a64938e..c0d9d50ce9 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
@@ -38,6 +38,8 @@
 #include "minddata/dataset/util/lock.h"
 #include "minddata/dataset/util/cond_var.h"
 #include "minddata/dataset/util/queue_map.h"
+#include "minddata/dataset/util/task_manager.h"
+#include "minddata/dataset/util/wait_post.h"
 namespace mindspore {
 namespace dataset {
@@ -50,6 +52,7 @@ class CacheClient {
   friend class CreateCacheRequest;
   friend class CacheRowRequest;
   friend class BatchFetchRequest;
+  friend class BatchCacheRowsRequest;
 
   /// \brief A builder to help creating a CacheClient object
   class Builder {
@@ -180,6 +183,11 @@ class CacheClient {
   /// \return Status object
   Status GetStat(CacheServiceStat *);
 
+  /// \brief Get the state of a cache server
+  /// \param[out] out Pointer to an int8_t
+  /// \return Status object
+  Status GetState(int8_t *);
+
   /// \brief Cache the schema at the cache server
   /// \param map The unordered map of the schema
   /// \return Status object
@@ -230,6 +238,7 @@ class CacheClient {
   int32_t GetPort() const { return port_; }
   int32_t GetNumConnections() const { return num_connections_; }
   int32_t GetPrefetchSize() const { return prefetch_size_; }
+  int32_t GetClientId() const { return client_id_; }
 
   /// MergeOp will notify us when the server can't cache any more rows.
   /// We will stop any attempt to fetch any rows that are most likely
@@ -250,6 +259,20 @@ class CacheClient {
     return false;
   }
 
+  // Default size of the async write buffer
+  constexpr static int64_t kAsyncBufferSize = 16 * 1048576L;  // 16M
+  constexpr static int32_t kNumAsyncBuffer = 2;
+
+  /// Force a final flush to the cache server. Must be called when receiving an eoe.
+  Status FlushAsyncWriteBuffer() {
+    if (async_buffer_stream_) {
+      return async_buffer_stream_->SyncFlush(true);
+    }
+    return Status::OK();
+  }
+
+  Status AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in);
+
  private:
   mutable RWLock mux_;
   uint64_t cache_mem_sz_;
@@ -288,6 +311,62 @@ class CacheClient {
     std::set<row_id_type> gap_;
   };
   std::unique_ptr<CacheMissKeys> cache_miss_keys_;
+
+  /// A data stream of back-to-back serialized tensor rows.
+  class AsyncBufferStream {
+   public:
+    AsyncBufferStream();
+    ~AsyncBufferStream();
+
+    /// \brief Initialize an async write buffer
+    Status Init(CacheClient *cc);
+
+    /// A worker will call the API AsyncWrite to put a TensorRow into the data stream.
+    /// A background thread will stream the data to the cache server.
+    /// The result of calling AsyncWrite is not immediately known, or it can be the last
+    /// result of some previous flush.
+    /// \note Need to call SyncFlush to do the final flush.
+    Status AsyncWrite(const TensorRow &row);
+    Status SyncFlush(bool blocking = false);
+
+    /// This maps a physical shared memory to the data stream.
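+    /// Each AsyncWriter owns one kAsyncBufferSize window of that stream;
+    /// [begin_addr_, end_addr_) is the slice of logical addresses currently mapped
+    /// onto the window, and bytes_avail_ counts down as concurrent writers finish
+    /// their copies, which is how SyncFlush knows every reserved byte has landed.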
+    class AsyncWriter {
+     public:
+      friend class AsyncBufferStream;
+      Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);
+
+     private:
+      std::shared_ptr<BatchCacheRowsRequest> rq;
+      void *buffer_;
+      int32_t num_ele_;                   // How many tensor rows in this buffer
+      int64_t begin_addr_;                // Start of the logical address of the data stream
+      std::atomic<int64_t> end_addr_;     // End of the logical address of the data stream
+      std::atomic<int64_t> bytes_avail_;  // Number of bytes remaining
+    };
+
+    /// \brief Release the shared memory during shutdown
+    /// \note but needs the comm layer to be alive.
+    Status ReleaseBuffer();
+
+   private:
+    Status flush_rc_;
+    WaitPost writer_wp_;
+    WaitPost flush_wp_;
+    RWLock mux_;
+    TaskGroup vg_;
+    CacheClient *cc_;
+    int64_t offset_addr_;
+    AsyncWriter buf_arr_[kNumAsyncBuffer];
+    int32_t cur_;
+    std::atomic<int64_t> next_addr_;
+
+    /// \brief Entry point of the async flush thread.
+    Status AsyncFlush();
+  };
+  std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
+
+  /// \brief Serialize a Tensor into the async buffer.
+  Status AsyncWriteRow(const TensorRow &row);
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
index 894f6e5714..d2f7287305 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
@@ -37,6 +37,10 @@ namespace dataset {
 /// For a small amount, we won't get any benefit using the shared memory method because we need
 /// two rpc requests to use it.
 constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
+/// \brief Default size (in GB) of shared memory we are going to create
+constexpr static int32_t kDefaultSharedMemorySize = 4;
+/// \brief Memory Cap ratio used by the server
+constexpr static float kDefaultMemoryCapRatio = 0.8;
 /// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
 constexpr static uint32_t kLocalClientSupport = 1;
 /// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
@@ -46,7 +50,15 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
 constexpr static int32_t kSharedMessageSize = 2048;
 
 /// \brief State of CacheService at the server.
-enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
+enum class CacheServiceState : int8_t {
+  kNone = 0,
+  kBuildPhase = 1,
+  kFetchPhase = 2,
+  kNoLocking = 3,
+  kOutOfMemory = 4,
+  kNoSpace = 5,
+  kError = 127
+};
 
 /// \brief Convert a Status object into a protobuf
 /// \param rc[in] Status object
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc
index 8b83b3fefa..f20cfbe586 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc
@@ -23,8 +23,7 @@
 namespace mindspore {
 namespace dataset {
-CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
-    : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
+CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port) : port_(port) {
   // Setup a path for unix socket.
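   // (For the default port this yields a path along the lines of "/tmp/cache_server_p50052";
   // the exact prefix is whatever PortToUnixSocketPath uses, shown here only as an illustration.)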
   unix_socket_ = PortToUnixSocketPath(port);
   // We can't generate the ftok key yet until the unix_socket_ is created
@@ -70,14 +69,6 @@ Status CacheServerGreeterImpl::Run() {
   server_ = builder.BuildAndStart();
   if (server_) {
     MS_LOG(INFO) << "Server listening on " << server_address;
-#if CACHE_LOCAL_CLIENT
-    RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
-    shm_key_ = shm_pool_->GetKey();
-    MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
-    auto cs = CacheServer::GetInstance().GetHWControl();
-    // This shared memory is a hot memory and we will interleave among all the numa nodes.
-    cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
-#endif
   } else {
     std::string errMsg = "Fail to start server. ";
     if (port_tcpip != port_) {
@@ -147,14 +138,18 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
   // Now we pass the address of this instance to CacheServer's main loop.
   MS_LOG(DEBUG) << "Handle request " << *this;
   // We will distribute the request evenly (or randomly) over all the numa nodes.
-  // The exception is BatchFetch which we need to pre-process here.
-  if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
-    rc_ = cs.BatchFetchRows(&rq_, &reply_);
-    if (!rc_.IsInterrupted()) {
-      Status2CacheReply(rc_, &reply_);
-      st_ = CacheServerRequest::STATE::FINISH;
-      responder_.Finish(reply_, grpc::Status::OK, this);
-    } else {
+  // The exceptions are BatchFetch and BatchCache, which we need to pre-process here.
+  // Some requests are also urgent enough that we want to process them here too.
+  if (type_ == BaseRequest::RequestType::kBatchFetchRows || type_ == BaseRequest::RequestType::kBatchCacheRows ||
+      type_ == BaseRequest::RequestType::kStopService || type_ == BaseRequest::RequestType::kAllocateSharedBlock ||
+      type_ == BaseRequest::RequestType::kFreeSharedBlock) {
+    cs.ProcessRequest(this);
+    // For cache_admin --stop, ProcessRequest is just acknowledging that we received the request. Now
+    // we call the real function.
+    if (type_ == BaseRequest::RequestType::kStopService) {
+      cs.GlobalShutdown();
+      return Status(StatusCode::kInterrupted);
+    } else if (rc_.IsInterrupted()) {
+      return rc_;
+    }
   } else {
@@ -191,10 +186,12 @@ Status CacheServerGreeterImpl::MonitorUnixSocket() {
     // If the unix socket is recreated for whatever reason, this server instance will be stale and
     // no other process can communicate with us. In this case we need to shut down ourselves.
     if (p.Exists()) {
+      auto &cs = CacheServer::GetInstance();
       SharedMemory::shm_key_t key;
       RETURN_IF_NOT_OK(PortToFtok(port_, &key));
-      if (key != shm_key_) {
-        std::string errMsg = "Detected the unix socket has changed. Previous key " + std::to_string(shm_key_) +
+      auto shm_key = cs.GetKey();
+      if (key != shm_key) {
+        std::string errMsg = "Detected the unix socket has changed. Previous key " + std::to_string(shm_key) +
                              ". New key " + std::to_string(key) + ". Shutting down server";
Shutting down server"; MS_LOG(ERROR) << errMsg; RETURN_STATUS_UNEXPECTED(errMsg); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h index d8bf2ed6fd..4f8daad1bc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h @@ -18,12 +18,14 @@ #include #include +#include #include #include #include #include "minddata/dataset/engine/cache/cache_common.h" -#include "minddata/dataset/engine/cache/cache_arena.h" +#include "minddata/dataset/engine/cache/cache_ipc.h" #include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/arena.h" #include "minddata/dataset/util/status.h" #include "minddata/dataset/util/task_manager.h" @@ -75,7 +77,7 @@ class CacheServerGreeterImpl final { friend class CacheServer; public: - explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb); + explicit CacheServerGreeterImpl(int32_t port); virtual ~CacheServerGreeterImpl(); /// \brief Brings up gRPC server /// \return none @@ -83,24 +85,18 @@ class CacheServerGreeterImpl final { /// \brief Entry function to handle cache server request Status HandleRequest(int32_t worker_id); - /// Return the shared memory pool. - /// \return Return the shared memory pool - CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); } - /// \brief Montor the status of the unix socket in case it is gone. Status MonitorUnixSocket(); + /// \brief This shutdown down the comm layer void Shutdown(); private: int32_t port_; - size_t shm_pool_sz_in_gb_; std::string unix_socket_; CacheServerGreeter::AsyncService svc_; std::unique_ptr cq_; std::unique_ptr server_; - std::unique_ptr shm_pool_; - SharedMemory::shm_key_t shm_key_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc index b429c209a6..f377942021 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc @@ -209,6 +209,14 @@ void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) { #endif } +void CacheServerHW::AssignToNode(numa_id_t numa_id, void *ptr, size_t sz) { +#ifdef NUMA_ENABLED + if (numa_enabled()) { + numa_tonode_memory(ptr, sz, numa_id); + } +#endif +} + bool CacheServerHW::numa_enabled() { #ifdef NUMA_ENABLED return (numa_available() != -1); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h index 586cb5ad8c..a57205b77e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h @@ -63,6 +63,9 @@ class CacheServerHW { /// \brief Interleave a given memory block. Used by shared memory only. static void InterleaveMemory(void *ptr, size_t sz); + /// \brief Assign a given memory block to a numa node. Used by shared memory only. + void AssignToNode(numa_id_t numa_id, void *ptr, size_t sz); + /// \brief Set default memory policy. 
   static Status SetDefaultMemoryPolicy(CachePoolPolicy);
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc
index adb5f6450b..1b822e684b 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.cc
@@ -53,7 +53,7 @@ Status SharedMessage::Create() {
 
 Status SharedMessage::SendStatus(const Status &rc) {
   CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
-  StatusMsgBuf msg{
+  CacheMsgBuf msg{
     1,
   };
   msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
@@ -71,7 +71,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
 Status SharedMessage::ReceiveStatus(Status *rc) {
   RETURN_UNEXPECTED_IF_NULL(rc);
   CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
-  struct StatusMsgBuf msg {};
+  struct CacheMsgBuf msg {};
   auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
   if (err == -1) {
     std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h
index 946104686a..c314236887 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h
@@ -28,7 +28,7 @@ namespace mindspore {
 namespace dataset {
 
 /// A message queue structure between the parent and the child process
-struct StatusMsgBuf {
+struct CacheMsgBuf {
   int64_t mtype;
   union {
     char mtext[1];
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc
index ef45a8f877..cd25371336 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc
@@ -168,12 +168,14 @@ Path CachePool::GetSpillPath() const {
 }
 
 CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
+  tree_->LockShared();  // Prevent any node split while we search.
   CacheStat cs{-1, -1, 0, 0, 0, 0};
   int64_t total_sz = 0;
   if (tree_->begin() != tree_->end()) {
     cs.min_key = tree_->begin().key();
     cs.max_key = cs.min_key;  // will adjust later.
     for (auto it = tree_->begin(); it != tree_->end(); ++it) {
+      it.LockShared();
       total_sz += it.value().sz;
       if (it.value().ptr != nullptr) {
         ++cs.num_mem_cached;
@@ -190,6 +192,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
         }
       }
       cs.max_key = cur_key;
+      it.Unlock();
     }
   }
   if (total_sz > 0) {
@@ -199,6 +202,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
       cs.average_cache_sz = 1;
     }
   }
+  tree_->Unlock();
   return cs;
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc
index e38d138434..9fb7b0623f 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc
@@ -58,7 +58,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te
   if (sent_using_local_bypass) {
     MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
     // Allocate shared memory from the server
-    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
+    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
     RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
     RETURN_IF_NOT_OK(mem_rq->Wait());
    addr_ = mem_rq->GetAddr();
@@ -305,6 +305,15 @@ Status GetStatRequest::PostReply() {
   return Status::OK();
 }
 
+Status GetCacheStateRequest::PostReply() {
+  try {
+    cache_service_state_ = std::stoi(reply_.result());
+  } catch (const std::exception &e) {
+    RETURN_STATUS_UNEXPECTED(e.what());
+  }
+  return Status::OK();
+}
+
 Status ListSessionsRequest::PostReply() {
   auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
   auto session_vector = msg->sessions();
@@ -333,5 +342,13 @@ Status ServerStopRequest::PostReply() {
   return Status::OK();
 }
 
+BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
+    : BaseRequest(RequestType::kBatchCacheRows) {
+  rq_.set_connection_id(cc->server_connection_id_);
+  rq_.set_client_id(cc->client_id_);
+  rq_.add_buf_data(cc->cookie());
+  rq_.add_buf_data(std::to_string(addr));
+  rq_.add_buf_data(std::to_string(num_ele));
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h
index 5dd4d2d8cf..fb922d012a 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h
@@ -78,6 +78,9 @@ class BaseRequest {
     kListSessions = 16,
     kConnectReset = 17,
     kInternalFetchRow = 18,
+    kBatchCacheRows = 19,
+    kInternalCacheRow = 20,
+    kGetCacheState = 21,
     // Add new request before it.
     kRequestUnknown = 32767
   };
@@ -133,10 +136,11 @@ class BaseRequest {
 class FreeSharedBlockRequest : public BaseRequest {
  public:
   friend class CacheServer;
-  explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
+  explicit FreeSharedBlockRequest(connection_id_type connection_id, int32_t client_id, int64_t addr)
       : BaseRequest(RequestType::kFreeSharedBlock) {
     rq_.set_connection_id(connection_id);
     rq_.add_buf_data(std::to_string(addr));
+    rq_.set_client_id(client_id);
   }
   ~FreeSharedBlockRequest() override = default;
 };
@@ -178,7 +182,7 @@ class CacheRowRequest : public BaseRequest {
   /// the shared memory by sending another request. The following function will generate a suitable
   /// request for the CacheClient to send.
   std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
-    return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_);
+    return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), rq_.client_id(), addr_);
   }
 
  private:
@@ -271,6 +275,24 @@ class GetStatRequest : public BaseRequest {
   CacheServiceStat stat_{};
 };
 
+/// \brief Get the state of a cache service
+class GetCacheStateRequest : public BaseRequest {
+ public:
+  friend class CacheServer;
+  explicit GetCacheStateRequest(connection_id_type connection_id)
+      : BaseRequest(RequestType::kGetCacheState), cache_service_state_(0) {
+    rq_.set_connection_id(connection_id);
+  }
+  ~GetCacheStateRequest() override = default;
+
+  Status PostReply() override;
+
+  auto GetState() const { return cache_service_state_; }
+
+ private:
+  int8_t cache_service_state_;
+};
+
 /// \brief Request to cache a schema
 class CacheSchemaRequest : public BaseRequest {
  public:
@@ -367,10 +389,11 @@ class ListSessionsRequest : public BaseRequest {
 class AllocateSharedBlockRequest : public BaseRequest {
  public:
   friend class CacheServer;
-  explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
+  explicit AllocateSharedBlockRequest(connection_id_type connection_id, int32_t client_id, size_t requestedSz)
      : BaseRequest(RequestType::kAllocateSharedBlock) {
     rq_.set_connection_id(connection_id);
     rq_.add_buf_data(std::to_string(requestedSz));
+    rq_.set_client_id(client_id);
   }
   ~AllocateSharedBlockRequest() override = default;
 
@@ -420,6 +443,13 @@ class ConnectResetRequest : public BaseRequest {
     return Status::OK();
   }
 };
+
+class BatchCacheRowsRequest : public BaseRequest {
+ public:
+  friend class CacheServer;
+  explicit BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele);
+  ~BatchCacheRowsRequest() override = default;
+};
 }  // namespace dataset
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
index 0a15e48d65..0329d5f46a 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
@@ -106,14 +106,18 @@ Status CacheServer::DoServiceStart() {
   RETURN_IF_NOT_OK(free_list_->Register(&vg_));
   // Start the comm layer
   try {
-    comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
+    comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
     RETURN_IF_NOT_OK(comm_layer_->Run());
-    // Bring up a thread to monitor the unix socket in case it is removed.
-    auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
-    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
   } catch (const std::exception &e) {
     RETURN_STATUS_UNEXPECTED(e.what());
   }
+#if CACHE_LOCAL_CLIENT
+  RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
+  // Bring up a thread to monitor the unix socket in case it is removed. But it must be done
+  // after we have created the unix socket.
+  auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
+  RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
+#endif
   // Spawn a few threads to serve the real request.
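  // Each worker drains its own queue; a request reaches a worker either from a grpc
  // thread via PushRequest(worker_id, rq), e.g. PushRequest(GetRandomWorker(), rq) as
  // BatchCacheRows does below, or via PushRequest(GetWorkerByNumaId(node), rq) when
  // numa placement matters.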
   auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
   for (auto i = 0; i < num_workers_; ++i) {
@@ -350,11 +354,12 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
 
 Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
   auto connection_id = rq->connection_id();
+  auto client_id = rq->client_id();
+  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
   // Hold the shared lock to prevent the cache from being dropped.
   SharedLock lck(&rwLock_);
   CacheService *cs = GetService(connection_id);
-  auto shared_pool = comm_layer_->GetSharedMemoryPool();
-  auto *base = shared_pool->SharedMemoryBaseAddr();
+  auto *base = SharedMemoryBaseAddr();
   // Ensure we got 3 pieces of data coming in
   CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
   // First piece of data is the cookie and is required
@@ -381,8 +386,10 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
       rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
     }
   }
-  // Return the block to the shared memory.
-  shared_pool->Deallocate(p);
+  // Return the block to the shared memory only if it is not an internal request.
+  if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
+    DeallocateSharedMemory(client_id, p);
+  }
   return rc;
 }
 
@@ -450,6 +457,7 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice
 Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
   auto connection_id = rq->connection_id();
+  auto client_id = rq->client_id();
   // Hold the shared lock to prevent the cache from being dropped.
   SharedLock lck(&rwLock_);
   CacheService *cs = GetService(connection_id);
@@ -490,14 +498,13 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
   reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
   if (local_bypass) {
     // We will use shared memory
-    auto shared_pool = comm_layer_->GetSharedMemoryPool();
-    auto *base = shared_pool->SharedMemoryBaseAddr();
+    auto *base = SharedMemoryBaseAddr();
     void *q = nullptr;
-    RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
+    RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
     WritableSlice dest(q, mem_sz);
     Status rc = BatchFetch(fbb, &dest);
     if (rc.IsError()) {
-      shared_pool->Deallocate(q);
+      DeallocateSharedMemory(client_id, q);
       return rc;
     }
     // We can't return the absolute address which makes no sense to the client.
@@ -597,7 +604,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
   // First piece of data is the cookie
   CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
   auto &cookie = rq->buf_data(0);
-  // We can only allow to switch phase is the cookie match.
+  // We can only allow the phase to switch if the cookie matches.
   if (cookie == cs->cookie()) {
     RETURN_IF_NOT_OK(cs->BuildPhaseDone());
   } else {
@@ -713,6 +720,203 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
   return Status::OK();
 }
 
+Status CacheServer::BatchCacheRows(CacheRequest *rq) {
+  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == 3, "Expect three pieces of data");
+  int32_t numQ = GetNumGrpcWorkers();
+  auto rng = GetRandomDevice();
+  std::uniform_int_distribution<int32_t> distribution(0, numQ - 1);
+  int32_t qID = distribution(rng);
+  std::vector<CacheServerRequest *> cache_rq_list;
+  try {
+    auto &cookie = rq->buf_data(0);
+    auto connection_id = rq->connection_id();
+    auto client_id = rq->client_id();
+    int64_t offset_addr;
+    int32_t num_elem;
+    auto *base = SharedMemoryBaseAddr();
+    offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
+    auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
+    num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
+    cache_rq_list.reserve(num_elem);
+    // Get a set of free requests and push them into the queues.
+    for (auto i = 0; i < num_elem; ++i) {
+      auto start = reinterpret_cast<int64_t>(p);
+      auto msg = GetTensorRowHeaderMsg(p);
+      p += msg->size_of_this();
+      for (auto k = 0; k < msg->column()->size(); ++k) {
+        p += msg->data_sz()->Get(k);
+      }
+      CacheServerRequest *cache_rq;
+      RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
+      cache_rq_list.push_back(cache_rq);
+      // Fill in details.
+      cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
+      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
+      cache_rq->rq_.set_connection_id(connection_id);
+      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
+      cache_rq->rq_.set_client_id(client_id);
+      cache_rq->rq_.set_flag(kDataIsInSharedMemory);
+      cache_rq->rq_.add_buf_data(cookie);
+      cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
+      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p) - start));
+      RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
+    }
+    // Now wait for all of them to come back.
+    Status rc;
+    for (CacheServerRequest *cache_rq : cache_rq_list) {
+      RETURN_IF_NOT_OK(cache_rq->Wait());
+      if (cache_rq->rc_.IsError() && !cache_rq->rc_.IsInterrupted() && rc.IsOk()) {
+        rc = cache_rq->rc_;
+      }
+      RETURN_IF_NOT_OK(ReturnRequestTag(cache_rq));
+    }
+    return rc;
+  } catch (const std::exception &e) {
+    RETURN_STATUS_UNEXPECTED(e.what());
+  }
+  return Status::OK();
+}
+
+void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
+  bool internal_request = false;
+  auto &rq = cache_req->rq_;
+  auto &reply = cache_req->reply_;
+  // Except for creating a new session, we expect cs is not null.
+  switch (cache_req->type_) {
+    case BaseRequest::RequestType::kCacheRow:
+    case BaseRequest::RequestType::kInternalCacheRow: {
+      // Look into the flag to see where we can find the data and
+      // call the appropriate method.
+      auto flag = rq.flag();
+      if (BitTest(flag, kDataIsInSharedMemory)) {
+        cache_req->rc_ = FastCacheRow(&rq, &reply);
+        internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
+      } else {
+        cache_req->rc_ = CacheRow(&rq, &reply);
+      }
+      break;
+    }
+    case BaseRequest::RequestType::kBatchCacheRows: {
+      cache_req->rc_ = BatchCacheRows(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kBatchFetchRows: {
+      cache_req->rc_ = BatchFetchRows(&rq, &reply);
+      break;
+    }
+    case BaseRequest::RequestType::kInternalFetchRow: {
+      internal_request = true;
+      auto connection_id = rq.connection_id();
+      SharedLock lck(&rwLock_);
+      CacheService *cs = GetService(connection_id);
+      if (cs == nullptr) {
+        std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
+        cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+      } else {
+        cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
+      }
+      break;
+    }
+    case BaseRequest::RequestType::kCreateCache: {
+      cache_req->rc_ = CreateService(&rq, &reply);
+      break;
+    }
+    case BaseRequest::RequestType::kGetCacheMissKeys: {
+      cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
+      break;
+    }
+    case BaseRequest::RequestType::kDestroyCache: {
+      cache_req->rc_ = DestroyCache(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kGetStat: {
+      cache_req->rc_ = GetStat(&rq, &reply);
+      break;
+    }
+    case BaseRequest::RequestType::kCacheSchema: {
+      cache_req->rc_ = CacheSchema(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kFetchSchema: {
+      cache_req->rc_ = FetchSchema(&rq, &reply);
+      break;
+    }
+    case BaseRequest::RequestType::kBuildPhaseDone: {
+      cache_req->rc_ = BuildPhaseDone(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kDropSession: {
+      cache_req->rc_ = DestroySession(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kGenerateSessionId: {
+      cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
+      break;
+    }
+    case BaseRequest::RequestType::kAllocateSharedBlock: {
+      cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
+      break;
+    }
+    case BaseRequest::RequestType::kFreeSharedBlock: {
+      cache_req->rc_ = FreeSharedMemory(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kStopService: {
+      // This command shuts down everything.
+      // But we first reply back to the client that we received the request.
+      // The real shutdown work will be done by the caller.
+      cache_req->rc_ = AcknowledgeShutdown(cache_req);
+      break;
+    }
+    case BaseRequest::RequestType::kHeartBeat: {
+      cache_req->rc_ = Status::OK();
+      break;
+    }
+    case BaseRequest::RequestType::kToggleWriteMode: {
+      cache_req->rc_ = ToggleWriteMode(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kListSessions: {
+      cache_req->rc_ = ListSessions(&reply);
+      break;
+    }
+    case BaseRequest::RequestType::kConnectReset: {
+      cache_req->rc_ = ConnectReset(&rq);
+      break;
+    }
+    case BaseRequest::RequestType::kGetCacheState: {
+      auto connection_id = rq.connection_id();
+      SharedLock lck(&rwLock_);
+      CacheService *cs = GetService(connection_id);
+      if (cs == nullptr) {
+        std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
+        cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+      } else {
+        auto state = cs->GetState();
+        reply.set_result(std::to_string(static_cast<int8_t>(state)));
+        cache_req->rc_ = Status::OK();
+      }
+      break;
+    }
+    default:
+      std::string errMsg("Unknown request type : ");
+      errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
+      cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
+  }
+  // Notify it is done, and move on to the next request.
+  Status2CacheReply(cache_req->rc_, &reply);
+  cache_req->st_ = CacheServerRequest::STATE::FINISH;
+  // We will re-tag the request back to the grpc queue. Once it comes back from the client,
+  // the CacheServerRequest, i.e. the pointer cache_req, will be freed.
+  if (!internal_request) {
+    cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
+  } else {
+    // This is an internal request and is not tied to rpc. But we need to post because there
+    // is a thread waiting on the completion of this request.
+    cache_req->wp_.Set();
+  }
+}
+
 /// \brief This is the main loop the cache server thread(s) are running.
 /// Each thread will pop a request and send the result back to the client using grpc
 /// \return
@@ -722,121 +926,9 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
   auto &my_que = cache_q_->operator[](worker_id);
   // Loop forever until we are interrupted or shutdown.
   while (!global_shutdown_) {
-    bool internal_request = false;
     CacheServerRequest *cache_req = nullptr;
     RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
-    auto &rq = cache_req->rq_;
-    auto &reply = cache_req->reply_;
-    // Except for creating a new session, we expect cs is not null.
-    switch (cache_req->type_) {
-      case BaseRequest::RequestType::kCacheRow: {
-        // Look into the flag to see where we can find the data and
-        // call the appropriate method.
-        auto flag = rq.flag();
-        if (BitTest(flag, kDataIsInSharedMemory)) {
-          cache_req->rc_ = FastCacheRow(&rq, &reply);
-        } else {
-          cache_req->rc_ = CacheRow(&rq, &reply);
-        }
-        break;
-      }
-      case BaseRequest::RequestType::kInternalFetchRow: {
-        internal_request = true;
-        auto connection_id = rq.connection_id();
-        SharedLock lck(&rwLock_);
-        CacheService *cs = GetService(connection_id);
-        if (cs == nullptr) {
-          std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
-          cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
-        } else {
-          cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
-        }
-        break;
-      }
-      case BaseRequest::RequestType::kCreateCache: {
-        cache_req->rc_ = CreateService(&rq, &reply);
-        break;
-      }
-      case BaseRequest::RequestType::kGetCacheMissKeys: {
-        cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
-        break;
-      }
-      case BaseRequest::RequestType::kDestroyCache: {
-        cache_req->rc_ = DestroyCache(&rq);
-        break;
-      }
-      case BaseRequest::RequestType::kGetStat: {
-        cache_req->rc_ = GetStat(&rq, &reply);
-        break;
-      }
-      case BaseRequest::RequestType::kCacheSchema: {
-        cache_req->rc_ = CacheSchema(&rq);
-        break;
-      }
-      case BaseRequest::RequestType::kFetchSchema: {
-        cache_req->rc_ = FetchSchema(&rq, &reply);
-        break;
-      }
-      case BaseRequest::RequestType::kBuildPhaseDone: {
-        cache_req->rc_ = BuildPhaseDone(&rq);
-        break;
-      }
-      case BaseRequest::RequestType::kDropSession: {
-        cache_req->rc_ = DestroySession(&rq);
-        break;
-      }
-      case BaseRequest::RequestType::kGenerateSessionId: {
-        cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
-        break;
-      }
-      case BaseRequest::RequestType::kAllocateSharedBlock: {
-        cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
-        break;
-      }
-      case BaseRequest::RequestType::kFreeSharedBlock: {
-        cache_req->rc_ = FreeSharedMemory(&rq);
-        break;
-      }
-      case BaseRequest::RequestType::kStopService: {
-        // This command shutdowns everything.
-        cache_req->rc_ = GlobalShutdown(cache_req);
-        break;
-      }
-      case BaseRequest::RequestType::kHeartBeat: {
-        cache_req->rc_ = Status::OK();
-        break;
-      }
-      case BaseRequest::RequestType::kToggleWriteMode: {
-        cache_req->rc_ = ToggleWriteMode(&rq);
-        break;
-      }
-      case BaseRequest::RequestType::kListSessions: {
-        cache_req->rc_ = ListSessions(&reply);
-        break;
-      }
-      case BaseRequest::RequestType::kConnectReset: {
-        cache_req->rc_ = ConnectReset(&rq);
-        break;
-      }
-      default:
-        std::string errMsg("Unknown request type : ");
-        errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
-        cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
-    }
-    // Notify it is done, and move on to the next request.
-    Status2CacheReply(cache_req->rc_, &reply);
-    cache_req->st_ = CacheServerRequest::STATE::FINISH;
-    // We will re-tag the request back to the grpc queue. Once it comes back from the client,
-    // the CacheServerRequest, i.e. the pointer cache_req, will be free
-    if (!global_shutdown_) {
-      if (!internal_request) {
-        cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
-      } else {
-        // This is an internal request and is not tied to rpc. But need to post because there
-        // is a thread waiting on the completion of this request.
-        cache_req->wp_.Set();
-      }
-    }
+    ProcessRequest(cache_req);
   }
   return Status::OK();
 }
@@ -869,6 +961,11 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int
     MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build "
                        "that is compiled with numa support for more optimal performance";
   }
+  // We create the shared memory here, so cap its size at the default maximum.
+  if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) {
+    MS_LOG(INFO) << "Shared memory size is readjusted to " << kDefaultSharedMemorySize << " GB.";
+    shared_memory_sz_in_gb_ = kDefaultSharedMemorySize;
+  }
 }
 
 Status CacheServer::Run(int msg_qid) {
@@ -965,24 +1062,34 @@ session_id_type CacheServer::GenerateSessionID() {
 }
 
 Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
-  auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
-  auto shared_pool = comm_layer_->GetSharedMemoryPool();
-  auto *base = shared_pool->SharedMemoryBaseAddr();
-  void *p = nullptr;
-  RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p));
-  // We can't return the absolute address which makes no sense to the client.
-  // Instead we return the difference.
-  auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
-  reply->set_result(std::to_string(difference));
+  auto client_id = rq->client_id();
+  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
+  try {
+    auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
+    void *p = nullptr;
+    RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
+    auto *base = SharedMemoryBaseAddr();
+    // We can't return the absolute address which makes no sense to the client.
+    // Instead we return the difference.
+    auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
+    reply->set_result(std::to_string(difference));
+  } catch (const std::exception &e) {
+    RETURN_STATUS_UNEXPECTED(e.what());
+  }
   return Status::OK();
 }
 
 Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
-  auto shared_pool = comm_layer_->GetSharedMemoryPool();
-  auto *base = shared_pool->SharedMemoryBaseAddr();
-  auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
-  auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
-  shared_pool->Deallocate(p);
+  auto client_id = rq->client_id();
+  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
+  auto *base = SharedMemoryBaseAddr();
+  try {
+    auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
+    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
+    DeallocateSharedMemory(client_id, p);
+  } catch (const std::exception &e) {
+    RETURN_STATUS_UNEXPECTED(e.what());
+  }
   return Status::OK();
 }
 
@@ -992,7 +1099,7 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) {
   return Status::OK();
 }
 
-Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
+Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
   auto *rq = &cache_req->rq_;
   auto *reply = &cache_req->reply_;
   if (!rq->buf_data().empty()) {
@@ -1008,9 +1115,10 @@
     }
   }
   reply->set_result("OK");
-  Status2CacheReply(cache_req->rc_, reply);
-  cache_req->st_ = CacheServerRequest::STATE::FINISH;
-  cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req);
+  return Status::OK();
+}
+
+void CacheServer::GlobalShutdown() {
   // Let's shutdown in proper order.
@@ -992,7 +1099,7 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) {
   return Status::OK();
 }
 
-Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
+Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
   auto *rq = &cache_req->rq_;
   auto *reply = &cache_req->reply_;
   if (!rq->buf_data().empty()) {
@@ -1008,9 +1115,10 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
     }
   }
   reply->set_result("OK");
-  Status2CacheReply(cache_req->rc_, reply);
-  cache_req->st_ = CacheServerRequest::STATE::FINISH;
-  cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req);
+  return Status::OK();
+}
+
+void CacheServer::GlobalShutdown() {
   // Let's shut down in proper order.
   bool expected = false;
   if (global_shutdown_.compare_exchange_strong(expected, true)) {
@@ -1032,7 +1140,6 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
       it = all_caches_.erase(it);
     }
   }
-  return Status::OK();
 }
 
 worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
@@ -1053,6 +1160,12 @@ worker_id_t CacheServer::GetRandomWorker() const {
   return dist(gen);
 }
 
+Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
+  return shm_->AllocateSharedMemory(client_id, sz, p);
+}
+
+void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); }
+
 Status CacheServer::Builder::IpcResourceCleanup() {
   Status rc;
   SharedMemory::shm_key_t shm_key;
@@ -1124,8 +1237,8 @@ CacheServer::Builder::Builder()
     : top_("/tmp"),
       num_workers_(std::thread::hardware_concurrency() / 2),
       port_(50052),
-      shared_memory_sz_in_gb_(4),
-      memory_cap_ratio_(0.8) {
+      shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
+      memory_cap_ratio_(kDefaultMemoryCapRatio) {
   if (num_workers_ == 0) {
     num_workers_ = 1;
   }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
index 2fa07ee4d5..b6ac0d0f8e 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
@@ -25,12 +25,14 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 #include
 #include
+#include "minddata/dataset/engine/cache/cache_arena.h"
 #include "minddata/dataset/engine/cache/cache_hw.h"
 #include "minddata/dataset/engine/cache/cache_numa.h"
 #include "minddata/dataset/engine/cache/cache_service.h"
@@ -196,15 +198,31 @@ class CacheServer : public Service {
   /// \brief Check if we bind threads to numa cores
   bool IsNumaAffinityOn() const { return numa_affinity_; }
 
-  /// \brief Internal function to do row batch fetch
-  /// \param rq Request
-  /// \param reply Reply
-  /// \return Status object
-  Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
-
   /// \brief Return the memory cap ratio
   float GetMemoryCapRatio() const { return memory_cap_ratio_; }
 
+  /// \brief How a request is handled.
+  /// \note that it can be processed immediately by a grpc thread or routed to a server thread
+  /// which is pinned to some numa node core.
+  void ProcessRequest(CacheServerRequest *cache_req);
+
+  void GlobalShutdown();
+
+  /// \brief This returns where we attach to the shared memory.
+  /// Some gRPC requests will ask for a shared memory block, and
+  /// we can't return the absolute address as this makes no sense
+  /// in the client. So instead we will return an address relative
+  /// to the base address of the shared memory where we attach to.
+  /// \return Base address of the shared memory.
+  const void *SharedMemoryBaseAddr() const { return shm_->SharedMemoryBaseAddr(); }
+
+  /// \brief Return the public key of the shared memory.
+  int32_t GetKey() const { return shm_->GetKey(); }
+
+  Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);
+
+  void DeallocateSharedMemory(int32_t client_id, void *p);
+
  private:
   static std::once_flag init_instance_flag_;
   static CacheServer *instance_;
@@ -228,6 +246,7 @@ class CacheServer : public Service {
   std::map<worker_id_t, Task *> numa_tasks_;
   bool numa_affinity_;
   std::vector<int32_t> shutdown_qIDs_;
+  std::unique_ptr<CachedSharedMemory> shm_;
 
   /// \brief Constructor
   /// \param spill_path Top directory for spilling buffers to.
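The new `GlobalShutdown()` guards its teardown with `compare_exchange_strong` on an atomic flag, so a racing duplicate kStopService request runs the shutdown exactly once. The guard in isolation, as a runnable sketch (names are illustrative):

```cpp
#include <atomic>
#include <iostream>

std::atomic<bool> global_shutdown_{false};

// Only the first caller flips the flag and runs the teardown; later callers
// (or a racing duplicate kStopService request) see the exchange fail and
// return immediately. This mirrors the compare_exchange_strong guard above.
void GlobalShutdownOnce() {
  bool expected = false;
  if (global_shutdown_.compare_exchange_strong(expected, true)) {
    std::cout << "shutting down caches, rpc threads, message queue" << std::endl;
  }
}

int main() {
  GlobalShutdownOnce();  // performs the shutdown
  GlobalShutdownOnce();  // no-op: flag already set
  return 0;
}
```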
@@ -315,7 +334,7 @@ class CacheServer : public Service {
 
   /// \brief A proper shutdown of the server
   /// \return Status object
-  Status GlobalShutdown(CacheServerRequest *);
+  Status AcknowledgeShutdown(CacheServerRequest *cache_req);
 
   /// \brief Find keys that will be cache miss
   /// \return Status object
@@ -332,12 +351,19 @@ class CacheServer : public Service {
   /// \brief Connect request by a pipeline
   Status ConnectReset(CacheRequest *rq);
 
+  /// \brief Internal function to do row batch fetch
+  /// \param rq Request
+  /// \param reply Reply
+  /// \return Status object
+  Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
+
   /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
   /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
   /// \param[in] v A vector of row id.
   /// \param[out] out A contiguous memory buffer that holds the requested rows.
   /// \return Status object
   Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
+  Status BatchCacheRows(CacheRequest *rq);
 };
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc
index ee6fac25be..4d746bcb2d 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc
@@ -68,10 +68,11 @@ Status CacheService::DoServiceStop() {
 
 Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
   SharedLock rw(&rw_lock_);
   RETURN_UNEXPECTED_IF_NULL(row_id_generated);
-  if (st_ == CacheServiceState::kFetchPhase) {
+  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
     // For this kind of cache service, once we are done with the build phase into fetch phase, we can't
     // allow others to cache more rows.
-    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
+    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
+                             std::to_string(static_cast<int>(st_.load())));
   }
   if (st_ == CacheServiceState::kNoLocking) {
     // We ignore this write request once we turn off locking on the B+ tree. So we will just
@@ -119,6 +120,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
   if (rc == Status(StatusCode::kDuplicateKey)) {
     MS_LOG(DEBUG) << "Ignoring duplicate key.";
   } else {
+    if (HasBuildPhase()) {
+      // For cache service that has a build phase, record the error in the state
+      // so other clients can be aware of the new state. There is nothing one can
+      // do to resume other than to drop the cache.
+      if (rc.IsNoSpace()) {
+        st_ = CacheServiceState::kNoSpace;
+      } else if (rc.IsOutofMemory()) {
+        st_ = CacheServiceState::kOutOfMemory;
+      }
+    }
     RETURN_IF_NOT_OK(rc);
   }
   return Status::OK();
@@ -130,10 +141,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
 
 Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
   SharedLock rw(&rw_lock_);
   RETURN_UNEXPECTED_IF_NULL(row_id_generated);
-  if (st_ == CacheServiceState::kFetchPhase) {
+  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
     // For this kind of cache service, once we are done with the build phase into fetch phase, we can't
     // allow others to cache more rows.
-    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
+    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
Current phase: " + + std::to_string(static_cast(st_.load()))); } if (st_ == CacheServiceState::kNoLocking) { // We ignore write this request once we turn off locking on the B+ tree. So we will just @@ -161,6 +173,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ if (rc == Status(StatusCode::kDuplicateKey)) { MS_LOG(DEBUG) << "Ignoring duplicate key."; } else { + if (HasBuildPhase()) { + // For cache service that has a build phase, record the error in the state + // so other clients can be aware of the new state. There is nothing one can + // do to resume other than to drop the cache. + if (rc.IsNoSpace()) { + st_ = CacheServiceState::kNoSpace; + } else if (rc.IsOutofMemory()) { + st_ = CacheServiceState::kOutOfMemory; + } + } RETURN_IF_NOT_OK(rc); } return Status::OK(); @@ -202,16 +224,17 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { SharedLock rw(&rw_lock_); RETURN_UNEXPECTED_IF_NULL(out); out->stat_ = cp_->GetStat(); - out->state_ = static_cast(st_); + out->state_ = static_cast(st_.load()); return Status::OK(); } Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector &v, const std::shared_ptr &fbb) { SharedLock rw(&rw_lock_); - if (st_ == CacheServiceState::kBuildPhase) { + if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) { // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. - RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " + + std::to_string(static_cast(st_.load()))); } std::vector> datalocator_v; datalocator_v.reserve(v.size()); @@ -271,7 +294,8 @@ Status CacheService::FetchSchema(std::string *out) const { SharedLock rw(&rw_lock_); if (st_ == CacheServiceState::kBuildPhase) { // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. - RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. 
Current phase: " + + std::to_string(static_cast(st_.load()))); } RETURN_UNEXPECTED_IF_NULL(out); // We are going to use std::string to allocate and hold the result which will be eventually @@ -292,6 +316,7 @@ Status CacheService::BuildPhaseDone() { UniqueLock rw(&rw_lock_); st_ = CacheServiceState::kFetchPhase; cp_->SetLocking(false); + MS_LOG(WARNING) << "Locking mode is switched off."; return Status::OK(); } else { RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h index d1b10f7c4d..593c6f9941 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h @@ -91,6 +91,8 @@ class CacheService : public Service { /// \param[in/out] A pointer to a pre-allocated ServiceStat structure /// \return Status Object Status GetStat(ServiceStat *); + /// \brief Return the current state + CacheServiceState GetState() const { return st_.load(); } /// \brief Cache schema /// \param buf A Google Flatbuffer that contains the schema /// \param len size of the buffer @@ -131,7 +133,7 @@ class CacheService : public Service { bool generate_id_; std::string cookie_; std::atomic num_clients_; - CacheServiceState st_; + std::atomic st_; std::string schema_; std::shared_ptr numa_pool_; // We also cache the result from calling FindKeysMiss because it is expensive. Besides user make diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc index acf2b29028..90d821addc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc @@ -427,6 +427,7 @@ Status CachePerfRun::Run() { } // Now we create the children knowing all two sets of message queues are constructed. + auto start_tick = std::chrono::steady_clock::now(); for (auto i = 0; i < num_pipelines_; ++i) { auto pid = fork(); if (pid == 0) { @@ -502,6 +503,10 @@ Status CachePerfRun::Run() { // Wait until all pipelines finish the first epoch. RETURN_IF_NOT_OK(pipeline_wp_.Wait()); + auto end_tick = std::chrono::steady_clock::now(); + + int64_t elapse_time = std::chrono::duration_cast(end_tick - start_tick).count(); + std::cout << "Epoch one (build phase) elapsed time " << elapse_time << " seconds" << std::endl; std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer() << std::endl; @@ -543,6 +548,7 @@ Status CachePerfRun::Run() { epoch_sync_cnt_ = 0; pipeline_wp_.Clear(); epoch_results_.clear(); + start_tick = std::chrono::steady_clock::now(); // Signal each pipeline to start for (auto msg_qid : msg_send_lists_) { CachePerfMsg msg; @@ -551,6 +557,9 @@ Status CachePerfRun::Run() { } // Wait for the child to finish RETURN_IF_NOT_OK(pipeline_wp_.Wait()); + end_tick = std::chrono::steady_clock::now(); + elapse_time = std::chrono::duration_cast(end_tick - start_tick).count(); + std::cout << "Epoch " << epoch_num << " elapsed time " << elapse_time << " seconds" << std::endl; std::cout << "Epoch " << epoch_num << " (read phase) per pipeline per worker summary. 
Buffer size = " << cc_->GetPrefetchSize() << std::endl; PrintEpochSummary(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc index 280bc08052..b1a48518ad 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc @@ -238,6 +238,9 @@ Status CachePipelineRun::RunFirstEpoch() { RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); } + // Final flush + cc_->FlushAsyncWriteBuffer(); + // Send a message saying epoch one done for this pipeline. EpochDone proto; proto.set_pipeline(my_pipeline_); @@ -291,7 +294,7 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) { buffer->set_tensor_table(std::move(tensor_table)); // Measure the time to call WriteBuffer auto start_tick = std::chrono::steady_clock::now(); - rc = cc_->WriteBuffer(std::move(buffer)); + rc = cc_->AsyncWriteBuffer(std::move(buffer)); auto end_tick = std::chrono::steady_clock::now(); if (rc.IsError()) { if (rc.IsOutofMemory() || rc.IsNoSpace()) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index d9908b2003..76b74ee6d7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -122,6 +122,17 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { if (db_ptr->eoe()) { // Ignore it. MS_LOG(DEBUG) << "Ignore eoe"; + // However we need to flush any left over from the async write buffer. But any error + // we are getting will just to stop caching but the pipeline will continue + Status rc; + if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) { + cache_missing_rows_ = false; + if (rc.IsOutofMemory() || rc.IsNoSpace()) { + cache_client_->ServerRunningOutOfResources(); + } else { + MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString(); + } + } } else { while (db_ptr->NumRows() > 0) { TensorRow row; @@ -143,6 +154,9 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { rc = rq->AsyncSendCacheRequest(cache_client_, row); if (rc.IsOk()) { RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + } else if (rc.IsOutofMemory() || rc.IsNoSpace()) { + cache_missing_rows_ = false; + cache_client_->ServerRunningOutOfResources(); } } } @@ -309,17 +323,25 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha if (st_.compare_exchange_strong(expected, State::kDirty)) { // We will do a deep copy but write directly into CacheRequest protobuf or shared memory Status rc; - cleaner_copy_ = std::make_shared(cc.get()); - rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); - if (rc.IsOk()) { - // Send the request async. The cleaner will check the return code. - rc = cc->PushRequest(cleaner_copy_); + rc = cc->AsyncWriteRow(row); + if (rc.get_code() == StatusCode::kNotImplementedYet) { + cleaner_copy_ = std::make_shared(cc.get()); + rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); + if (rc.IsOk()) { + // Send the request async. The cleaner will check the return code. + rc = cc->PushRequest(cleaner_copy_); + } + } else if (rc.IsOk()) { + // Set the state to clean even though it still sits in the cache client async buffer. + // The cleaner will then ignore it once the state is clean. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
index 0408eab1bf..d73b2c5228 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
@@ -109,7 +109,14 @@ Status CacheOp::CacheAllRows(int32_t worker_id) {
     RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
     while (!db_ptr->eof()) {
       if (!db_ptr->eoe()) {
-        RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
+        Status rc;
+        // Do the async write if we are attached to the shared memory.
+        rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr));
+        if (rc.get_code() == StatusCode::kNotImplementedYet) {
+          RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
+        } else if (rc.IsError()) {
+          return rc;
+        }
       } else {
         // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
         // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
@@ -139,21 +146,41 @@ Status CacheOp::WaitForCachingAllRows() {
   RETURN_IF_NOT_OK(rows_cache_done_.Wait());
   // Move from build phase to fetch phase if we are the one to fill the cache
   if (phase_ == Phase::kBuildPhase) {
+    RETURN_IF_NOT_OK(cache_client_->FlushAsyncWriteBuffer());  // One more flush
     RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
     // Move to the next phase
     phase_ = Phase::kFetchPhase;
   }
-  // Get statistics from the server, and if we are not the one to create the cache,
+  // If we are not the one to create the cache,
   // wait until the state changed from build phase to fetch phase.
-  CacheServiceStat stat{};
   bool BuildPhaseDone = true;
   do {
-    RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
-    BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase);
-    if (!BuildPhaseDone) {
-      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    int8_t out;
+    RETURN_IF_NOT_OK(cache_client_->GetState(&out));
+    auto state = static_cast<CacheServiceState>(out);
+    switch (state) {
+      case CacheServiceState::kBuildPhase:
+        // Do nothing. Continue to wait.
+        BuildPhaseDone = false;
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+        break;
+      case CacheServiceState::kFetchPhase:
+        BuildPhaseDone = true;
+        break;
+      case CacheServiceState::kOutOfMemory:
+        return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory");
+      case CacheServiceState::kNoSpace:
+        return Status(StatusCode::kNoSpace, "Cache server is running out of spill storage");
+      case CacheServiceState::kNone:
+      case CacheServiceState::kError:
+      default:
+        RETURN_STATUS_UNEXPECTED("Unexpected state: " + std::to_string(out));
     }
   } while (!BuildPhaseDone);
+  // Get statistics from the server, and if we are not the one to create the cache,
+  // wait until the state changed from build phase to fetch phase.
+  CacheServiceStat stat{};
+  RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
   const row_id_type min_key = stat.min_row_id;
   const row_id_type max_key = stat.max_row_id;
   num_rows_ = max_key - min_key + 1;
diff --git a/mindspore/ccsrc/minddata/dataset/util/btree.h b/mindspore/ccsrc/minddata/dataset/util/btree.h
index 69723ac2f6..b7a355c296 100644
--- a/mindspore/ccsrc/minddata/dataset/util/btree.h
+++ b/mindspore/ccsrc/minddata/dataset/util/btree.h
@@ -148,6 +148,12 @@ class BPlusTree {
     acquire_lock_ = on_off;
   }
 
+  void LockShared() { rw_lock_.LockShared(); }
+
+  void LockExclusive() { rw_lock_.LockExclusive(); }
+
+  void Unlock() { rw_lock_.Unlock(); }
+
  private:
   // Abstract class of a node (leaf or inner)
   class BaseNode {
@@ -409,6 +415,21 @@ class BPlusTree {
     bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
     bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }
 
+    void LockShared() {
+      cur_->rw_lock_.LockShared();
+      locked_ = true;
+    }
+
+    void LockExclusive() {
+      cur_->rw_lock_.LockExclusive();
+      locked_ = true;
+    }
+
+    void Unlock() {
+      cur_->rw_lock_.Unlock();
+      locked_ = false;
+    }
+
    private:
     typename BPlusTree::LeafNode *cur_;
     slot_type slot_;
@@ -458,6 +479,21 @@ class BPlusTree {
     bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
     bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }
 
+    void LockShared() {
+      cur_->rw_lock_.LockShared();
+      locked_ = true;
+    }
+
+    void LockExclusive() {
+      cur_->rw_lock_.LockExclusive();
+      locked_ = true;
+    }
+
+    void Unlock() {
+      cur_->rw_lock_.Unlock();
+      locked_ = false;
+    }
+
    private:
     const typename BPlusTree::LeafNode *cur_;
     slot_type slot_;
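The new LockShared/LockExclusive/Unlock accessors expose manual reader-writer locking on the whole tree, plus per-leaf locking through an iterator, presumably so a caller can still pin a node while reading after SetLocking(false) has disabled the tree's internal locking. A rough analog of that discipline using std::shared_mutex (note one difference: the real RWLock's Unlock() releases either mode, whereas shared_mutex needs matching unlock calls; all names below are illustrative, not the MindSpore API):

```cpp
#include <iostream>
#include <map>
#include <shared_mutex>
#include <string>

// A minimal analog of the manual locking the diff adds to BPlusTree: readers
// take a shared lock around lookups, writers an exclusive lock around inserts.
// std::shared_mutex stands in for the tree's RWLock.
class LockableIndex {
 public:
  void LockShared() { rw_lock_.lock_shared(); }
  void LockExclusive() { rw_lock_.lock(); }
  void Unlock() { rw_lock_.unlock(); }  // releases the exclusive lock
  void UnlockShared() { rw_lock_.unlock_shared(); }

  void Insert(int key, const std::string &v) { data_[key] = v; }
  const std::string &At(int key) { return data_.at(key); }

 private:
  std::shared_mutex rw_lock_;
  std::map<int, std::string> data_;
};

int main() {
  LockableIndex index;
  index.LockExclusive();  // writer pins the structure
  index.Insert(1, "row one");
  index.Unlock();

  index.LockShared();     // reader pins it only against writers
  std::cout << index.At(1) << std::endl;
  index.UnlockShared();
  return 0;
}
```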