From 572c5e5f2928cafa91b9ab625f4fe8768c2d9b2a Mon Sep 17 00:00:00 2001
From: Lixia Chen
Date: Thu, 20 Aug 2020 12:27:39 -0400
Subject: [PATCH] Rebase up to 542a52fbf8781d1b7df71f6b06a72847a12c5b66

---
 .../dataset/engine/cache/CMakeLists.txt | 26 +-
 .../dataset/engine/cache/cache_admin_arg.cc | 21 +-
 .../dataset/engine/cache/cache_admin_arg.h | 3 +-
 .../dataset/engine/cache/cache_client.cc | 44 +-
 .../dataset/engine/cache/cache_client.h | 9 +-
 .../dataset/engine/cache/cache_common.h | 24 +-
 .../dataset/engine/cache/cache_grpc.proto | 5 +-
 .../dataset/engine/cache/cache_grpc_server.cc | 30 +-
 .../dataset/engine/cache/cache_grpc_server.h | 2 +
 .../minddata/dataset/engine/cache/cache_hw.cc | 220 +++++++
 .../minddata/dataset/engine/cache/cache_hw.h | 81 +++
 .../dataset/engine/cache/cache_main.cc | 2 +
 .../dataset/engine/cache/cache_numa.cc | 224 +++++++
 .../dataset/engine/cache/cache_numa.h | 195 ++++++
 .../{util => engine/cache}/cache_pool.cc | 150 +++--
 .../{util => engine/cache}/cache_pool.h | 34 +-
 .../dataset/engine/cache/cache_request.cc | 70 ++-
 .../dataset/engine/cache/cache_request.h | 50 +-
 .../dataset/engine/cache/cache_server.cc | 390 ++++++++----
 .../dataset/engine/cache/cache_server.h | 85 ++-
 .../dataset/engine/cache/cache_service.cc | 229 +++----
 .../dataset/engine/cache/cache_service.h | 40 +-
 .../dataset/engine/cache/de_tensor.fbs | 24 +-
 .../dataset/engine/cache/perf/CMakeLists.txt | 32 +
 .../dataset/engine/cache/perf/cache_msg.cc | 48 ++
 .../dataset/engine/cache/perf/cache_msg.h | 78 +++
 .../dataset/engine/cache/perf/cache_perf.cc | 39 ++
 .../engine/cache/perf/cache_perf.proto | 39 ++
 .../engine/cache/perf/cache_perf_run.cc | 575 ++++++++++++++++++
 .../engine/cache/perf/cache_perf_run.h | 100 +++
 .../engine/cache/perf/cache_pipeline.cc | 44 ++
 .../engine/cache/perf/cache_pipeline_run.cc | 471 ++++++++++++++
 .../engine/cache/perf/cache_pipeline_run.h | 117 ++++
 .../cache}/storage_container.cc | 2 +-
 .../cache}/storage_container.h | 0
 .../{util => engine/cache}/storage_manager.cc | 2 +-
 .../{util => engine/cache}/storage_manager.h | 2 +-
 .../engine/datasetops/cache_base_op.cc | 39 +-
 .../dataset/engine/datasetops/cache_base_op.h | 1 -
 .../engine/datasetops/cache_merge_op.cc | 3 +-
 .../dataset/engine/datasetops/cache_op.cc | 2 +-
 .../engine/opt/pre/cache_error_pass.cc | 78 ++-
 .../dataset/engine/opt/pre/cache_error_pass.h | 73 +++
 .../minddata/dataset/util/CMakeLists.txt | 3 -
 .../ccsrc/minddata/dataset/util/allocator.h | 5 +
 mindspore/ccsrc/minddata/dataset/util/path.h | 12 +
 mindspore/ccsrc/minddata/dataset/util/task.cc | 13 +-
 mindspore/ccsrc/minddata/dataset/util/task.h | 15 +-
 tests/ut/cpp/dataset/cache_op_test.cc | 1 -
 tests/ut/python/cachetests/cachetest_py.sh | 11 +-
 tests/ut/python/dataset/test_cache_map.py | 213 +++--
 tests/ut/python/dataset/test_cache_nomap.py | 85 +++
 tests/ut/python/test_server_stop_testcase.sh | 10 +
 53 files changed, 3545 insertions(+), 526 deletions(-)
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h
 rename mindspore/ccsrc/minddata/dataset/{util => engine/cache}/cache_pool.cc (65%)
 rename mindspore/ccsrc/minddata/dataset/{util => engine/cache}/cache_pool.h (82%)
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/CMakeLists.txt
 create mode 100644
mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.h
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.proto
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.h
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.h
 rename mindspore/ccsrc/minddata/dataset/{util => engine/cache}/storage_container.cc (98%)
 rename mindspore/ccsrc/minddata/dataset/{util => engine/cache}/storage_container.h (100%)
 rename mindspore/ccsrc/minddata/dataset/{util => engine/cache}/storage_manager.cc (98%)
 rename mindspore/ccsrc/minddata/dataset/{util => engine/cache}/storage_manager.h (97%)
 create mode 100755 tests/ut/python/test_server_stop_testcase.sh

diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
index 802fbf3779..a0b4382dfd 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(perf EXCLUDE_FROM_ALL)
 include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
 set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
 ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
@@ -5,6 +6,18 @@ ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engin
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
+# Try to find numa header file and its library
+find_file(NUMA_HDR NAMES "numa.h")
+
+if (EXISTS ${NUMA_HDR})
+ ADD_DEFINITIONS(-DNUMA_ENABLED)
+ MESSAGE("Numa package found")
+endif ()
+
+if (${CMAKE_SYSTEM_NAME} MATCHES "Linux")
+ ADD_DEFINITIONS(-DCACHE_LOCAL_CLIENT)
+endif ()
+
 add_library(engine-cache-client OBJECT
 cache_client.cc
 cache_fbb.cc
@@ -20,8 +33,13 @@ if (ENABLE_CACHE)
 ${CACHE_GRPC_SRCS}
 cache_grpc_server.cc
 cache_arena.cc
+ cache_hw.cc
+ cache_numa.cc
+ cache_pool.cc
 cache_service.cc
- cache_server.cc)
+ cache_server.cc
+ storage_manager.cc
+ storage_container.cc)
 add_executable(cache_server cache_main.cc)
 target_link_libraries(cache_server
@@ -39,6 +57,10 @@ if (ENABLE_CACHE)
 target_link_libraries(cache_server mindspore::glog)
 endif ()
+ if (EXISTS ${NUMA_HDR})
+ target_link_libraries(cache_server numa)
+ endif ()
+
 add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
 target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES})
@@ -49,7 +71,7 @@ if (ENABLE_CACHE)
 add_dependencies(engine-cache-server generated_engine_files)
 else ()
- ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
+ ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
 target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
 endif ()
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc
b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc index 4fef953b2d..a7e26b08de 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -31,7 +32,9 @@ namespace mindspore { namespace dataset { - +const int32_t CacheAdminArgHandler::kDefaultNumWorkers = std::thread::hardware_concurrency() > 2 + ? std::thread::hardware_concurrency() / 2 + : 1; const char CacheAdminArgHandler::kServerBinary[] = "cache_server"; const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp"; @@ -304,8 +307,10 @@ Status CacheAdminArgHandler::Validate() { } // Additional checks here - if (num_workers_ < 1 || num_workers_ > 100) - return Status(StatusCode::kSyntaxError, "Number of workers must be in range of 1 and 100."); + auto max_num_workers = std::max(std::thread::hardware_concurrency(), 100); + if (num_workers_ < 1 || num_workers_ > max_num_workers) + return Status(StatusCode::kSyntaxError, + "Number of workers must be in range of 1 and " + std::to_string(max_num_workers) + "."); if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3)."); if (memory_cap_ratio_ <= 0 || memory_cap_ratio_ > 1) return Status(StatusCode::kSyntaxError, "Memory cap ratio should be positive and no greater than 1"); @@ -354,13 +359,15 @@ Status CacheAdminArgHandler::RunCommand() { std::vector session_info = rq->GetSessionCacheInfo(); if (!session_info.empty()) { std::cout << std::setw(12) << "Session" << std::setw(12) << "Cache Id" << std::setw(12) << "Mem cached" - << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::endl; + << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" << std::setw(10) << "Numa hit" + << std::endl; for (auto curr_session : session_info) { std::string cache_id; std::string stat_mem_cached; std::string stat_disk_cached; std::string stat_avg_cached; - int32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF); + std::string stat_numa_hit; + uint32_t crc = (curr_session.connection_id & 0x00000000FFFFFFFF); cache_id = (curr_session.connection_id == 0) ? "n/a" : std::to_string(crc); stat_mem_cached = (curr_session.stats.num_mem_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_mem_cached); @@ -368,10 +375,12 @@ Status CacheAdminArgHandler::RunCommand() { (curr_session.stats.num_disk_cached == 0) ? "n/a" : std::to_string(curr_session.stats.num_disk_cached); stat_avg_cached = (curr_session.stats.avg_cache_sz == 0) ? "n/a" : std::to_string(curr_session.stats.avg_cache_sz); + stat_numa_hit = + (curr_session.stats.num_numa_hit == 0) ? "n/a" : std::to_string(curr_session.stats.num_numa_hit); std::cout << std::setw(12) << curr_session.session_id << std::setw(12) << cache_id << std::setw(12) << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached - << std::endl; + << std::setw(10) << stat_numa_hit << std::endl; } } else { std::cout << "No active sessions." 
<< std::endl; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h index 5a78ebf0c7..020cb1b415 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "minddata/dataset/util/status.h" #include "minddata/dataset/engine/cache/cache_client.h" @@ -29,7 +30,7 @@ namespace dataset { class CacheAdminArgHandler { public: - static constexpr int32_t kDefaultNumWorkers = 32; + static const int32_t kDefaultNumWorkers; static constexpr int32_t kDefaultSharedMemorySizeInGB = 4; static constexpr int32_t kDefaultLogLevel = 1; static constexpr float kMemoryCapRatio = 0.8; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc index 077304c1c9..9c23fec1a1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc @@ -17,7 +17,6 @@ #include #include "minddata/dataset/engine/cache/cache_client.h" #include "minddata/dataset/engine/cache/cache_request.h" -#include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/cache/cache_fbb.h" #include "minddata/dataset/util/bit.h" @@ -59,6 +58,7 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool : server_connection_id_(0), cache_mem_sz_(cache_mem_sz), spill_(spill), + client_id_(-1), local_bypass_(false), hostname_(std::move(hostname)), port_(port), @@ -71,6 +71,22 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool CacheClient::~CacheClient() { cache_miss_keys_wp_.Set(); + if (client_id_ != -1) { + try { + // Send a message to the server, saying I am done. + auto rq = std::make_shared(server_connection_id_, client_id_); + Status rc = PushRequest(rq); + if (rc.IsOk()) { + rc = rq->Wait(); + if (rc.IsOk()) { + MS_LOG(INFO) << "Disconnect from server successful"; + } + } + } catch (const std::exception &e) { + // Can't do anything in destructor. So just log the error. 
+ MS_LOG(ERROR) << e.what(); + } + } (void)comm_->ServiceStop(); } @@ -85,7 +101,7 @@ void CacheClient::Print(std::ostream &out) const { } Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const { - auto rq = std::make_shared(server_connection_id_, cookie(), SupportLocalClient()); + auto rq = std::make_shared(this); RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row)); RETURN_IF_NOT_OK(PushRequest(rq)); RETURN_IF_NOT_OK(rq->Wait()); @@ -104,7 +120,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr &&in) const { for (auto i = 0; i < num_rows; ++i) { TensorRow row; RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); - arr[i] = std::make_shared(server_connection_id_, cookie(), SupportLocalClient()); + arr[i] = std::make_shared(this); RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row)); RETURN_IF_NOT_OK(PushRequest(arr[i])); } @@ -118,7 +134,7 @@ Status CacheClient::WriteBuffer(std::unique_ptr &&in) const { Status CacheClient::GetRows(const std::vector &row_id, TensorTable *out) const { RETURN_UNEXPECTED_IF_NULL(out); - auto rq = std::make_shared(server_connection_id_, row_id, SupportLocalClient()); + auto rq = std::make_shared(this, row_id); RETURN_IF_NOT_OK(PushRequest(rq)); RETURN_IF_NOT_OK(rq->Wait()); int64_t mem_addr; @@ -167,7 +183,7 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock. CacheServiceStat stat{}; RETURN_IF_NOT_OK(GetStat(&stat)); - if (stat.cache_service_state == static_cast(CacheService::State::kFetchPhase)) { + if (stat.cache_service_state == static_cast(CacheServiceState::kFetchPhase)) { return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); } } else { @@ -183,18 +199,16 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { // Start the comm layer to receive reply RETURN_IF_NOT_OK(comm_->ServiceStart()); // Initiate connection - auto rq = std::make_shared(cinfo_, cache_mem_sz_, createFlag); + auto rq = std::make_shared(this, cinfo_, cache_mem_sz_, createFlag); RETURN_IF_NOT_OK(PushRequest(rq)); Status rc = rq->Wait(); - if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) { - std::string cookie; - rq->ParseResult(&server_connection_id_, &cookie); - if (rc.IsOk()) { - // The 1st guy creating the cache will get a cookie back. - // But this object may be shared among pipelines and we don't want - // overwrite it. - cookie_ = cookie; - } + bool success = (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey); + // If we get kDuplicateKey, it just means we aren't the first one to create the cache, + // and we will continue to parse the result. 
+ if (rc.get_code() == StatusCode::kDuplicateKey) { + RETURN_IF_NOT_OK(rq->PostReply()); + } + if (success) { // Attach to shared memory for local client RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_)); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h index eec0ed7dfd..7f3a64938e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h @@ -47,6 +47,9 @@ namespace dataset { class CacheClient { public: friend class CacheMergeOp; + friend class CreateCacheRequest; + friend class CacheRowRequest; + friend class BatchFetchRequest; /// \brief A builder to help creating a CacheClient object class Builder { @@ -115,7 +118,7 @@ class CacheClient { session_id_type GetSessionId() const { return session_id_; } uint64_t GetCacheMemSz() const { return cache_mem_sz_; } bool isSpill() const { return spill_; } - const std::string &getHostname() const { return hostname_; } + const std::string &GetHostname() const { return hostname_; } int32_t GetPort() const { return port_; } int32_t GetNumConnections() const { return num_connections_; } int32_t GetPrefetchSize() const { return prefetch_size_; } @@ -256,8 +259,10 @@ class CacheClient { CacheClientInfo cinfo_; // The server_connection_id_ is the actual id we use for operations after the cache is built connection_id_type server_connection_id_; - // Some magic cookie returned from the cache server. + // Some magic cookie/id returned from the cache server. std::string cookie_; + int32_t client_id_; + std::vector cpu_list_; // Comm layer bool local_bypass_; std::string hostname_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h index 22761d099f..894f6e5714 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h @@ -20,11 +20,6 @@ /// both client and server side codes. Do not put code that is not common here. /// There are client and server specific header files. -// On platform like Windows, we may support only tcp/ip clients -#if !defined(_WIN32) && !defined(_WIN64) -#define CACHE_LOCAL_CLIENT 1 -#endif - #ifdef ENABLE_CACHE #include #endif @@ -50,6 +45,9 @@ constexpr static uint32_t kDataIsInSharedMemory = 2; /// \brief Size of each message used in message queue. constexpr static int32_t kSharedMessageSize = 2048; +/// \brief State of CacheService at the server. 
+enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking }; + /// \brief Convert a Status object into a protobuf /// \param rc[in] Status object /// \param reply[in/out] pointer to pre-allocated protobuf object @@ -61,6 +59,22 @@ inline void Status2CacheReply(const Status &rc, CacheReply *reply) { /// \param port /// \return unix socket url inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); } + +/// \brief Round up to the next 4k +inline int64_t round_up_4K(int64_t sz) { + // Since 4096 is a power of 2, a simple way to round up is add 4095 and mask off all the + // bits of 4095 + return static_cast(sz + 4095) & ~static_cast(4095); +} + +/// Memory policy +enum CachePoolPolicy : int8_t { kOnNode, kPreferred, kLocal, kInterleave, kNone }; + +/// Misc typedef +using worker_id_t = int32_t; +using numa_id_t = int32_t; +using cpu_id_t = int32_t; + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto index 68619d33ab..1ec829adb2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto @@ -32,12 +32,13 @@ message CacheRequest { uint32 flag = 2; oneof connect_info { // The server_connection_id is the actual id we use for operations after the cache is built - int64 connection_id = 3; + uint64 connection_id = 3; // But some request like CreateCache we have to use the session id and crc to connect to the server. CacheClientInfo connection_info = 4; } + int32 client_id = 5; // Everything else is just vector of buffers - repeated bytes buf_data = 5; + repeated bytes buf_data = 6; } message CacheReply { diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc index 892880f0ea..20804adaec 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc @@ -74,6 +74,9 @@ Status CacheServerGreeterImpl::Run() { #if CACHE_LOCAL_CLIENT RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_)); MS_LOG(INFO) << "Creation of local socket and shared memory successful"; + auto cs = CacheServer::GetInstance().GetHWControl(); + // This shared memory is a hot memory and we will interleave among all the numa nodes. + cs->InterleaveMemory(const_cast(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L); #endif } else { std::string errMsg = "Fail to start server. "; @@ -127,8 +130,13 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp st_ = STATE::PROCESS; svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this); } else if (st_ == STATE::PROCESS) { + auto &cs = CacheServer::GetInstance(); // Get a new tag and handle the next request before we serve the current request. - // The tag will be recycled when its state is changed to FINISH + // The tag will be recycled when its state is changed to FINISH. + // The number of free list queues is the same as the number of grpc threads. + // Where we get the free list it doesn't matter (as long we return it back to the right queue). + // We can round robin, use the qid or even use the worker id. 
We will use the free list queue + // where the current request comes from. CacheServerRequest *next_rq; RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq)); RETURN_IF_NOT_OK((*next_rq)(svc, cq)); @@ -138,8 +146,24 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp type_ = static_cast(rq_.type()); // Now we pass the address of this instance to CacheServer's main loop. MS_LOG(DEBUG) << "Handle request " << *this; - auto &cs = CacheServer::GetInstance(); - RETURN_IF_NOT_OK(cs.PushRequest(myQID, this)); + // We will distribute the request evenly (or randomly) over all the numa nodes. + // The exception is BatchFetch which we need to pre-process here. + if (type_ == BaseRequest::RequestType::kBatchFetchRows) { + rc_ = cs.BatchFetchRows(&rq_, &reply_); + if (!rc_.IsInterrupted()) { + Status2CacheReply(rc_, &reply_); + st_ = CacheServerRequest::STATE::FINISH; + responder_.Finish(reply_, grpc::Status::OK, this); + } else { + return rc_; + } + } else { + // When the number of grpc workers is the same as the server workers, we will use this queue id + // and push to the corresponding queue. + bool random = cs.GetNumWorkers() != cs.GetNumGrpcWorkers(); + worker_id_t worker_id = random ? cs.GetRandomWorker() : myQID; + RETURN_IF_NOT_OK(cs.PushRequest(worker_id, this)); + } } else if (st_ == STATE::FINISH) { MS_LOG(DEBUG) << *this << " Finished."; // Return back to the free list. diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h index b3d1cc5f70..a85ec8da07 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_ +#include #include #include #include @@ -34,6 +35,7 @@ namespace dataset { class CacheServerRequest : public BaseRequest { public: friend class CacheServer; + friend class CacheService; enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; explicit CacheServerRequest(int32_t queue_id) : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown), diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc new file mode 100644 index 0000000000..b429c209a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.cc @@ -0,0 +1,220 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#include "minddata/dataset/engine/cache/cache_hw.h" +#ifdef NUMA_ENABLED +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +namespace mindspore { +namespace dataset { +CacheServerHW::CacheServerHW() { + num_cpus_ = std::thread::hardware_concurrency(); + MS_LOG(DEBUG) << "Number of cpu(s) : " << num_cpus_; +#ifdef NUMA_ENABLED + if (numa_enabled()) { + MS_LOG(WARNING) << "Numa support enabled"; + for (auto i = 0; i <= numa_max_node(); ++i) { + int64_t free_avail; + int64_t mem_avail = numa_node_size(i, &free_avail); + MS_LOG(INFO) << "Total physical/free RAM in bytes at node " << i << " : " << mem_avail << "/" << free_avail; + } + } +#endif +} + +int64_t CacheServerHW::GetTotalSystemMemory() { + auto pages = sysconf(_SC_PHYS_PAGES); + auto page_size = sysconf(_SC_PAGE_SIZE); + auto total = static_cast(pages) * static_cast(page_size); + MS_LOG(INFO) << "Total physical RAM in bytes: " << total; + return total; +} + +Status CacheServerHW::SetDefaultMemoryPolicy(CachePoolPolicy policy) { +#ifdef NUMA_ENABLED + if (numa_enabled()) { + // Set our default memory policy. + switch (policy) { + case kLocal: + numa_set_localalloc(); + MS_LOG(DEBUG) << "Setting memory default policy to local node. Low level code may override the setting"; + break; + case kInterleave: + numa_set_interleave_mask(numa_all_nodes_ptr); + MS_LOG(DEBUG) << "Numa affinity is turned off. Use interleave memory policy as default."; + break; + case kOnNode: + case kPreferred: + RETURN_STATUS_UNEXPECTED("Unsupported memory policy"); + break; + case kNone: + default: + // No action taken. + break; + } + } +#endif + return Status::OK(); +} + +Status CacheServerHW::GetNumaNodeInfo() { + std::set numa_nodes_; + Path node(kSysNodePath); + auto it = Path::DirIterator::OpenDirectory(&node); + if (it == nullptr) { + MS_LOG(WARNING) << "Unable to open directory " << kSysNodePath << ". Skip scanning hardware info"; + return Status::OK(); + } + auto isdigit_string = [](const char *str) -> bool { + bool r = true; + for (auto i = 0; i < strlen(str); ++i) { + if (!std::isdigit(str[i])) { + r = false; + break; + } + } + return r; + }; + // Look for name starts with 'node' and followed by digits. + const char kNodeName[] = "node"; + while (it->hasNext()) { + auto p = it->next(); + const std::string entry = p.Basename(); + const char *name = entry.data(); + if (strncmp(name, kNodeName, 4) == 0 && isdigit_string(name + strlen(kNodeName))) { + numa_nodes_.insert(p); + } + } + // There should be at least one. But if not found in any case, just move on the + // rest of the server start up. + if (numa_nodes_.empty()) { + MS_LOG(WARNING) << "No numa nodes ? Skip scanning hardware info"; + return Status::OK(); + } + // For each numa node, get a list of CPU that is associated with it. + const char kCpuList[] = "cpulist"; + auto r = std::regex("[0-9]*-[0-9]*"); + for (Path p : numa_nodes_) { + auto node_dir = p.Basename().data(); + numa_id_t numa_node = strtol(node_dir + strlen(kNodeName), nullptr, 10); + Path f = p / kCpuList; + std::ifstream fs(f.toString()); + CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + f.toString()); + std::string cpu_string; + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + int32_t cpu_cnt = 0; + while (getline(fs, cpu_string)) { + // Now we parse the content of cpu_string. 
+ std::sregex_iterator iter(cpu_string.begin(), cpu_string.end(), r); + std::sregex_iterator end; + while (iter != end) { + auto match = iter->str(); + auto pos = match.find_first_of('-'); + std::string min = match.substr(0, pos); + std::string max = match.substr(pos + 1); + cpu_id_t cpu_min = strtol(min.data(), nullptr, 10); + cpu_id_t cpu_max = strtol(max.data(), nullptr, 10); + MS_LOG(DEBUG) << "Numa node " << numa_node << " CPU(s) : " << cpu_min << "-" << cpu_max; + for (int i = cpu_min; i <= cpu_max; ++i) { + CPU_SET(i, &cpuset); + ++cpu_cnt; + } + ++iter; + } + } + CHECK_FAIL_RETURN_UNEXPECTED(!fs.bad(), "Fail to read file: " + f.toString()); + fs.close(); + // Remember which cpu is attached to this numa node. + numa_cpuset_.emplace(numa_node, cpuset); + numa_cpu_cnt_.emplace(numa_node, cpu_cnt); + } + MS_LOG(DEBUG) << "Number of numa nodes : " << numa_cpuset_.size(); + return Status::OK(); +} + +Status CacheServerHW::SetAffinity(const Task &tk, numa_id_t numa_node) { + auto r = numa_cpuset_.find(numa_node); + if (r != numa_cpuset_.end()) { + auto err = pthread_setaffinity_np(tk.GetNativeHandle(), sizeof(r->second), &r->second); + if (err) { + std::string errMsg = "Unable to set affiity. Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + } else { + RETURN_STATUS_UNEXPECTED("Numa node " + std::to_string(numa_node) + " not found"); + } + return Status::OK(); +} + +std::vector CacheServerHW::GetCpuList(numa_id_t numa_id) { + std::vector v; + auto it = numa_cpuset_.find(numa_id); + if (it != numa_cpuset_.end()) { + auto &cpu_set = it->second; + for (auto i = 0; i < num_cpus_; ++i) { + if (CPU_ISSET(i, &cpu_set)) { + v.push_back(i); + } + } + } + return v; +} + +numa_id_t CacheServerHW::GetMyNode() const { + numa_id_t node_id = 0; + auto cpu = sched_getcpu(); +#ifdef NUMA_ENABLED + node_id = numa_node_of_cpu(cpu); +#else + bool found = false; + for (auto it : numa_cpuset_) { + cpu_set_t &cpu_set = it.second; + if (CPU_ISSET(cpu, &cpu_set)) { + node_id = it.first; + found = true; + break; + } + } + MS_LOG(DEBUG) << "cpu id " << cpu << " found : " << std::boolalpha << found; +#endif + return node_id; +} + +void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) { +#ifdef NUMA_ENABLED + if (numa_enabled()) { + numa_interleave_memory(ptr, sz, numa_all_nodes_ptr); + } +#endif +} + +bool CacheServerHW::numa_enabled() { +#ifdef NUMA_ENABLED + return (numa_available() != -1); +#else + return false; +#endif +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h new file mode 100644 index 0000000000..586cb5ad8c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_hw.h @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_ + +#ifdef NUMA_ENABLED +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/task.h" + +namespace mindspore { +namespace dataset { +class CacheServerHW { + public: + CacheServerHW(); + ~CacheServerHW() = default; + + /// \brief Get Numa node info without using numa library + /// \return Status object + Status GetNumaNodeInfo(); + + /// \brief Set thread affinity + Status SetAffinity(const Task &tk, numa_id_t numa_node); + + /// \brief Get total number of cpu(s) + int32_t GetCpuCount() const { return num_cpus_; } + + /// \brief Get total number of numa nodes + int32_t GetNumaNodeCount() const { return numa_cpuset_.empty() ? 1 : numa_cpuset_.size(); } + + /// \brief Get a list of cpu for a given numa node. + std::vector GetCpuList(numa_id_t numa_id); + + static bool numa_enabled(); + + /// \brief Return the numa the current thread is running on. + numa_id_t GetMyNode() const; + + /// \brief Interleave a given memory block. Used by shared memory only. + static void InterleaveMemory(void *ptr, size_t sz); + + /// \brief Set default memory policy. + static Status SetDefaultMemoryPolicy(CachePoolPolicy); + + /// \brief This returns the size (in bytes) of the physical RAM on the machine. + /// \return the size (in bytes) of the physical RAM on the machine. + static int64_t GetTotalSystemMemory(); + + private: + constexpr static char kSysNodePath[] = "/sys/devices/system/node"; + int32_t num_cpus_; + std::map numa_cpuset_; + std::map numa_cpu_cnt_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_HW_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc index 956db70246..868f5f445b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc @@ -54,6 +54,8 @@ ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds:: #endif try { rq->set_type(static_cast(type)); + rq->set_client_id(-1); + rq->set_flag(0); grpc::ChannelArguments args; grpc::ClientContext ctx; grpc::CompletionQueue cq; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.cc new file mode 100644 index 0000000000..35ddc1df9d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.cc @@ -0,0 +1,224 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_hw.h" +#include "minddata/dataset/engine/cache/cache_numa.h" +#include "minddata/dataset/util/random.h" +namespace mindspore { +namespace dataset { +NumaMemoryPool::NumaMemoryPool(std::shared_ptr hw, float memory_cap_ratio) + : hw_(std::move(hw)), memory_cap_ratio_(memory_cap_ratio) { + int64_t total_avail = 0; + // We will create a number of small Arenas to spread out the server threads so it + // will be less contention. If we link with the numa library, i.e. if + // NUMA_ENABLED is defined, we will make use of the low level numa library such that + // each Arena solely comes from one particular socket. + // The total number of Arenas will be controlled under the number of cpus. + auto num_cpus = hw_->GetCpuCount(); + memory_segments_.reserve(num_cpus); + arena_list_.reserve(num_cpus); + mux_ = std::make_unique(num_cpus); + auto num_memory_nodes = num_cpus; + int64_t max_avail = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; + int64_t arena_sz = max_avail / num_memory_nodes; + // If arena_sz is too small, lower the number of Arenas. + if (arena_sz < std::numeric_limits::max()) { + arena_sz = round_up_4K(std::numeric_limits::max()); + num_memory_nodes = max_avail / arena_sz; + if (num_memory_nodes == 0) { + num_memory_nodes = 1; + arena_sz = max_avail; + } + } + MS_LOG(INFO) << "Creating " << num_memory_nodes << " number of arena. Each one of size " << arena_sz; + +#ifdef NUMA_ENABLED + if (numa_available() != -1) { + auto num_numa_nodes = hw_->GetNumaNodeCount(); + numa_id_t node_id = 0; + for (auto i = 0; i < num_memory_nodes; ++i) { + auto success = CreateMultipleArenas(arena_sz, node_id++ % num_numa_nodes, 1); + total_avail += success * arena_sz; + } + } else { + auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes); + total_avail += success * arena_sz; + } +#else + auto success = CreateMultipleArenas(arena_sz, 0, num_memory_nodes); + total_avail += success * arena_sz; +#endif + memory_cap_ = total_avail; + MS_LOG(WARNING) << "Memory pool created. Total available memory " << memory_cap_ << " spread in " << nodes_.size() + << " arenas"; + int32_t slot = 0; + // Set up a map for future easy access. + for (auto node_id : nodes_) { + numa_map_[node_id].push_back(slot); + ++slot; + } +} + +int32_t NumaMemoryPool::CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count) { + int32_t success = 0; + for (auto i = 0; i < repeat_count; ++i) { +#ifdef NUMA_ENABLED + void *ptr = numa_alloc_onnode(segment_sz, node_id); +#else + void *ptr = malloc(segment_sz); +#endif + if (ptr != nullptr) { + memory_segments_.emplace_back(ptr, segment_sz); + arena_list_.push_back(std::make_unique(ptr, segment_sz)); + nodes_.push_back(node_id); + ++success; + } else { + // Skip the rest. 
+ break; + } + } + MS_LOG(DEBUG) << "Allocate " << success << " arenas from node " << node_id; + return success; +} + +NumaMemoryPool::~NumaMemoryPool() { + if (!memory_segments_.empty()) { + for (auto &s : memory_segments_) { +#ifdef NUMA_ENABLED + numa_free(s.first, s.second); +#else + free(s.first); +#endif + } + } +} + +Status NumaMemoryPool::Allocate(size_t n, void **p) { + RETURN_UNEXPECTED_IF_NULL(p); + auto mt = GetRandomDevice(); + Status rc; + void *ptr = nullptr; + auto num_segments = memory_segments_.size(); + CHECK_FAIL_RETURN_UNEXPECTED(num_segments > 0, "No numa nodes available"); + if (NumaAware()) { + auto num_numa_nodes = hw_->GetNumaNodeCount(); + // We will start from the numa node this worker id is running on and do a round robin search. + numa_id_t start = hw_->GetMyNode(); + numa_id_t node_id = start; + do { + auto it = numa_map_.find(node_id); + if (it != numa_map_.end()) { + auto &slots = it->second; + auto num_slots = slots.size(); + std::uniform_int_distribution distribution(0, num_slots - 1); + auto start_slot = distribution(mt); + int32_t inx = start_slot; + do { + int32_t k = slots.at(inx); + std::unique_lock lock_x(mux_[k]); + auto &impl = arena_list_.at(k); + rc = impl->Allocate(n, &ptr); + if (rc.IsOk()) { + *p = ptr; + break; + } else if (rc.IsOutofMemory()) { + inx = (inx + 1) % num_slots; + } else { + return rc; + } + } while (inx != start_slot); + } + // We have done searching for this numa node. If not found, move to the next node. + if (ptr == nullptr) { + node_id = (node_id + 1) % num_numa_nodes; + } else { + break; + } + } while (node_id != start); + } else { + // If not numa aware, just randomly pick a slot. + std::uniform_int_distribution distribution(0, num_segments - 1); + auto start_slot = distribution(mt); + int32_t slot = start_slot; + do { + std::unique_lock lock_x(mux_[slot]); + auto &impl = arena_list_.at(slot); + rc = impl->Allocate(n, &ptr); + if (rc.IsOk()) { + *p = ptr; + break; + } else if (rc.IsOutofMemory()) { + // Make the next arena and continue. + slot = (slot + 1) % num_segments; + } else { + return rc; + } + } while (slot != start_slot); + } + // Handle the case we have done one round robin search. + if (ptr == nullptr) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + return rc; +} + +void NumaMemoryPool::Deallocate(void *p) { + // Find out which numa slot it comes from. 
+ auto slot = Locate(p); + MS_ASSERT(slot != -1); + std::unique_lock lock_x(mux_[slot]); + auto &impl = arena_list_.at(slot); + impl->Deallocate(p); +} + +int NumaMemoryPool::PercentFree() const { + int percent_free = 0; + int num_arena = 0; + for (auto const &p : arena_list_) { + percent_free += p->PercentFree(); + num_arena++; + } + if (num_arena) { + return percent_free / num_arena; + } else { + return 100; + } +} + +int32_t NumaMemoryPool::Locate(void *p) const { + int32_t slot = 0; + char *mem = reinterpret_cast(p); + for (slot = 0; slot < memory_segments_.size(); ++slot) { + auto elem = memory_segments_.at(slot); + char *q = reinterpret_cast(elem.first); + if (mem >= q && mem < q + elem.second) { + return slot; + } + } + return -1; +} + +std::vector NumaMemoryPool::GetAvailableNodes() const { + std::vector v; + std::transform(numa_map_.begin(), numa_map_.end(), std::back_inserter(v), + [](const std::pair> &v) { return v.first; }); + return v; +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h new file mode 100644 index 0000000000..a9352c4826 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_numa.h @@ -0,0 +1,195 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_hw.h" +#include "minddata/dataset/util/arena.h" +#include "minddata/dataset/util/memory_pool.h" + +namespace mindspore { +namespace dataset { +/// \brief An allocator but for a particular numa node. +template +class NumaAllocator { + public: + explicit NumaAllocator(numa_id_t node_id, CachePoolPolicy policy) + : policy_(policy), numa_enabled_(false), node_id_(node_id) { +#ifdef NUMA_ENABLED + numa_enabled_ = numa_available() != -1; +#endif + } + ~NumaAllocator() = default; + + template + explicit NumaAllocator(NumaAllocator const &rhs) + : policy_(rhs.policy_), numa_enabled_(rhs.numa_enabled_), node_id_(rhs.node_id_) {} + + template + bool operator==(Allocator const &rhs) const { + return node_id_ == rhs.node_id_; + } + + template + bool operator!=(Allocator const &rhs) const { + return node_id_ != rhs.node_id_; + } + + template + friend class NumaAllocator; + + using value_type = T; + using pointer = T *; + using const_pointer = const T *; + using reference = T &; + using const_reference = const T &; + using size_type = uint64_t; + using difference_type = std::ptrdiff_t; + + template + struct rebind { + using other = Allocator; + }; + + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + + /// Allocate memory on this node only. 
Return nullptr if no memory on this numa node. + /// \note. This version will not throw if we can't allocate memory from this node. + /// User must check if the pointer returned is null or not. + pointer allocate(std::size_t n) noexcept { + auto sz = n * sizeof(T); + void *p = nullptr; +#ifdef NUMA_ENABLED + if (numa_enabled_) { + switch (policy_) { + case kPreferred: + numa_set_preferred(node_id_); + p = numa_alloc(sz); + break; + case kLocal: + p = numa_alloc_local(sz); + break; + case kInterleave: + p = numa_alloc_interleaved(sz); + break; + case kOnNode: + p = numa_alloc_onnode(sz, node_id_); + break; + case kNone: + default: + p = numa_alloc(sz); + break; + } + } else { + p = malloc(sz); + } +#else + p = malloc(sz); +#endif + return reinterpret_cast(p); + } + + /// Free a memory allocated on this node. + void deallocate(pointer p, std::size_t n) noexcept { +#ifdef NUMA_ENABLED + if (numa_enabled_) { + numa_free(p, n * sizeof(T)); + } else { + free(p); + } +#else + free(p); +#endif + } + + /// \brief Allow one to change to another numa node + void SetNodeId(numa_id_t node_id) { node_id_ = node_id; } + + /// \brif Getter for node_id; + numa_id_t GetNodeId() const { return node_id_; } + + /// \brief Getter for policy + CachePoolPolicy GetPolicy() const { return policy_; } + + private: + CachePoolPolicy policy_; + bool numa_enabled_; + numa_id_t node_id_; +}; + +/// \brief A NumaMemoryPool is like a CircularPool but all the arenas have already been allocated +/// and each one comes from a numa socket. Memory is allocated using OnNode policy. That is, +/// it is solely comes from one particular numa node, and is not interleaved. +class NumaMemoryPool : public MemoryPool { + public: + explicit NumaMemoryPool(std::shared_ptr hw, float memory_cap_ratio); + ~NumaMemoryPool() override; + + // As a derived class, we override the following functions + Status Allocate(size_t size, void **pVoid) override; + void Deallocate(void *pVoid) override; + Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override { RETURN_STATUS_UNEXPECTED("Not supported"); } + uint64_t get_max_size() const override { return std::numeric_limits::max(); } + int PercentFree() const override; + + /// \brief Return if the memory pool is numa aware + bool NumaAware() const { return CacheServerHW::numa_enabled(); } + + /// \brief. This returns all the numa nodes that we are able to allocate memory from. + std::vector GetAvailableNodes() const; + + /// \brief. Given a pointer (allocated from this pool), return the numa node where it is located. + /// \note. -1 is returned if not found. + numa_id_t FindNode(void *p) const { + auto slot = Locate(p); + if (slot != -1) { + return nodes_.at(slot); + } else { + return -1; + } + } + + /// \brief Return maximum available memory + int64_t GetAvailableMemory() const { return memory_cap_; } + + private: + std::shared_ptr hw_; + float memory_cap_ratio_; + int64_t memory_cap_; + std::vector> memory_segments_; + std::vector> arena_list_; + std::unique_ptr mux_; + std::vector nodes_; + std::map> numa_map_; + + /// \brief. Returns the slot that a given memory comes from. + /// \return slot from numa_segments. -1 if not found. + int32_t Locate(void *p) const; + + /// If numa library is not linked, or numa_availble() return -1, we will fall back to this method. 
+ int32_t CreateMultipleArenas(int64_t segment_sz, numa_id_t node_id, int32_t repeat_count); +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_NUMA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc similarity index 65% rename from mindspore/ccsrc/minddata/dataset/util/cache_pool.cc rename to mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc index 5972ada022..ae8ea73af9 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.cc @@ -15,18 +15,14 @@ */ #include #include "utils/ms_utils.h" -#include "minddata/dataset/util/cache_pool.h" +#include "minddata/dataset/engine/cache/cache_pool.h" +#include "minddata/dataset/engine/cache/cache_server.h" #include "minddata/dataset/util/services.h" namespace mindspore { namespace dataset { -CachePool::CachePool(const value_allocator &alloc, bool ourOwnArena, const std::string &root) - : alloc_(alloc), - root_(root), - subfolder_(Services::GetUniqueID()), - sm_(nullptr), - tree_(nullptr), - custom_arena_(ourOwnArena) {} +CachePool::CachePool(std::shared_ptr mp, const std::string &root) + : mp_(std::move(mp)), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} Status CachePool::DoServiceStart() { tree_ = std::make_shared(); @@ -36,10 +32,11 @@ Status CachePool::DoServiceStart() { RETURN_IF_NOT_OK(spill.CreateDirectories()); sm_ = std::make_shared(spill); RETURN_IF_NOT_OK(sm_->ServiceStart()); - MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); + MS_LOG(INFO) << "CachePool will use disk folder: " << spill.toString(); } return Status::OK(); } + Status CachePool::DoServiceStop() { Status rc; Status rc2; @@ -50,14 +47,14 @@ Status CachePool::DoServiceStop() { } } sm_.reset(); - // If it is our own arena, skip freeing individual pieces. - if (!custom_arena_) { - for (auto &bl : *tree_) { - if (bl.ptr != nullptr) { - alloc_.deallocate(bl.ptr, bl.sz); - } + + value_allocator alloc(mp_); + for (auto &bl : *tree_) { + if (bl.ptr != nullptr) { + alloc.deallocate(bl.ptr, bl.sz); } } + tree_.reset(); if (!root_.toString().empty()) { Path spill = GetSpillPath(); @@ -75,8 +72,10 @@ Status CachePool::DoServiceStop() { } return rc2; } + CachePool::~CachePool() noexcept { (void)ServiceStop(); } -Status CachePool::Insert(CachePool::key_type key, const std::vector &buf, bool writeToDiskDirectly) { + +Status CachePool::Insert(CachePool::key_type key, const std::vector &buf) { DataLocator bl; Status rc; size_t sz = 0; @@ -85,26 +84,35 @@ Status CachePool::Insert(CachePool::key_type key, const std::vectorAllocate(sz, reinterpret_cast(&bl.ptr)); + if (rc.IsOk()) { + // Write down which numa node where we allocate from. It only make sense if the policy is kOnNode. + if (CacheServerHW::numa_enabled()) { + auto &cs = CacheServer::GetInstance(); + auto node_id = cs.GetHWControl()->GetMyNode(); + bl.node_id = mp_->FindNode(bl.ptr); + CHECK_FAIL_RETURN_UNEXPECTED(bl.node_id != -1, "Allocator is not from numa memory pool"); + bl.node_hit = (bl.node_id == node_id); + } + // We will do a piecewise copy. 
+ WritableSlice dest(bl.ptr, bl.sz); + size_t pos = 0; + for (auto &v : buf) { + WritableSlice out(dest, pos); + rc = WritableSlice::Copy(&out, v); if (rc.IsError()) { - alloc_.deallocate(bl.ptr, sz); - bl.ptr = nullptr; - return rc; + break; } - } else if (sm_ != nullptr) { + pos += v.GetSize(); + } + if (rc.IsError()) { + mp_->Deallocate(bl.ptr); + bl.ptr = nullptr; + return rc; + } + } else if (rc.IsOutofMemory()) { + // If no memory, write to disk. + if (sm_ != nullptr) { MS_LOG(DEBUG) << "Spill to disk directly ... " << bl.sz << " bytes."; RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); } else { @@ -112,12 +120,8 @@ Status CachePool::Insert(CachePool::key_type key, const std::vectorWrite(&bl.storage_key, buf)); - } else { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } + } else { + return rc; } // Insert into the B+ tree. We may still get out of memory error. So need to catch it. try { @@ -127,10 +131,13 @@ Status CachePool::Insert(CachePool::key_type key, const std::vectorDeallocate(bl.ptr); + bl.ptr = nullptr; + return rc; } return rc; } + Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { RETURN_UNEXPECTED_IF_NULL(dest); auto r = tree_->Search(key); @@ -156,13 +163,14 @@ Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *byt } return Status::OK(); } -const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } + Path CachePool::GetSpillPath() const { auto spill = Path(root_) / subfolder_; return spill; } + CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { - CacheStat cs{-1, -1, 0, 0, 0}; + CacheStat cs{-1, -1, 0, 0, 0, 0}; int64_t total_sz = 0; if (tree_->begin() != tree_->end()) { cs.min_key = tree_->begin().key(); @@ -174,6 +182,9 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { } else { ++cs.num_disk_cached; } + if (it.value().node_hit) { + ++cs.num_numa_hit; + } auto cur_key = it.key(); if (GetMissingKeys) { for (auto i = cs.max_key + 1; i < cur_key; ++i) { @@ -192,49 +203,26 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const { } return cs; } -Status CachePool::Spill(CachePool::DataLocator *dl) { - if (sm_ == nullptr) { - RETURN_STATUS_UNEXPECTED("No disk storage to spill"); - } - RETURN_UNEXPECTED_IF_NULL(dl); - RETURN_UNEXPECTED_IF_NULL(dl->ptr); - if (dl->storage_key == 0) { - ReadableSlice data(dl->ptr, dl->sz); - RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); - } - alloc_.deallocate(dl->ptr, dl->sz); - dl->ptr = nullptr; - return Status::OK(); -} -Status CachePool::Locate(CachePool::DataLocator *dl) { - RETURN_UNEXPECTED_IF_NULL(dl); - if (dl->ptr == nullptr) { - if (sm_ == nullptr) { - RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); - } - try { - dl->ptr = alloc_.allocate(dl->sz); - WritableSlice dest(dl->ptr, dl->sz); - Status rc = Read(dl->storage_key, &dest); - if (rc.IsError()) { - alloc_.deallocate(dl->ptr, dl->sz); - dl->ptr = nullptr; - return rc; - } - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } - } - return Status::OK(); -} -size_t CachePool::GetSize(CachePool::key_type key) const { + +Status CachePool::GetDataLocator(key_type key, const std::shared_ptr &fbb, + flatbuffers::Offset *out) const { + RETURN_UNEXPECTED_IF_NULL(out); auto r = tree_->Search(key); if (r.second) { auto &it = r.first; - return it->sz; + DataLocatorMsgBuilder bld(*fbb); + bld.add_key(key); + bld.add_size(it->sz); + 
bld.add_node_id(it->node_id); + bld.add_addr(reinterpret_cast(it->ptr)); + auto offset = bld.Finish(); + *out = offset; } else { - return 0; + // Key not in the cache. + auto offset = CreateDataLocatorMsg(*fbb, key, 0, 0, 0); + *out = offset; } + return Status::OK(); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.h similarity index 82% rename from mindspore/ccsrc/minddata/dataset/util/cache_pool.h rename to mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.h index 77c1c06f24..cdd6d05f4c 100644 --- a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_pool.h @@ -19,11 +19,14 @@ #include #include #include +#include #include +#include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/engine/cache/cache_numa.h" +#include "minddata/dataset/engine/cache/storage_manager.h" #include "minddata/dataset/util/allocator.h" #include "minddata/dataset/util/service.h" #include "minddata/dataset/util/slice.h" -#include "minddata/dataset/util/storage_manager.h" #include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/btree.h" @@ -45,13 +48,15 @@ class CachePool : public Service { // An internal class to locate the whereabouts of a backed up buffer which can be either in class DataLocator { public: - DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} + DataLocator() : ptr(nullptr), sz(0), node_id(0), node_hit(false), storage_key(0) {} ~DataLocator() = default; DataLocator(const DataLocator &other) = default; DataLocator &operator=(const DataLocator &other) = default; DataLocator(DataLocator &&other) noexcept { ptr = other.ptr; sz = other.sz; + node_id = other.node_id; + node_hit = other.node_hit; storage_key = other.storage_key; other.ptr = nullptr; other.sz = 0; @@ -61,6 +66,8 @@ class CachePool : public Service { if (&other != this) { ptr = other.ptr; sz = other.sz; + node_id = other.node_id; + node_hit = other.node_hit; storage_key = other.storage_key; other.ptr = nullptr; other.sz = 0; @@ -70,6 +77,8 @@ class CachePool : public Service { } pointer ptr; size_t sz; + numa_id_t node_id; // where the numa node the memory is allocated to + bool node_hit; // we can allocate to the preferred node StorageManager::key_type storage_key; }; @@ -85,19 +94,20 @@ class CachePool : public Service { int64_t num_mem_cached; int64_t num_disk_cached; int64_t average_cache_sz; + int64_t num_numa_hit; std::vector gap; }; /// \brief Constructor /// \param alloc Allocator to allocate memory from /// \param root Optional disk folder to spill - explicit CachePool(const value_allocator &alloc, bool customArena, const std::string &root = ""); + explicit CachePool(std::shared_ptr mp, const std::string &root = ""); CachePool(const CachePool &) = delete; CachePool(CachePool &&) = delete; CachePool &operator=(const CachePool &) = delete; CachePool &operator=(CachePool &&) = delete; - ~CachePool() noexcept; + ~CachePool() noexcept override; Status DoServiceStart() override; Status DoServiceStop() override; @@ -110,7 +120,8 @@ class CachePool : public Service { /// \param[in] buf A sequence of ReadableSlice objects. 
/// \param[in] writeToDiskDirectly If true, no spill to disk if spill is enabled, or return no memory /// \return Error code - Status Insert(key_type key, const std::vector &buf, bool writeToDiskDirectly); + Status Insert(CachePool::key_type key, const std::vector &buf); + /// \brief Restore a cached buffer (from memory or disk) /// \param[in] key A previous key returned from Insert /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice @@ -118,18 +129,14 @@ class CachePool : public Service { /// \return Error code Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; - Status Spill(DataLocator *dl); - - Status Locate(DataLocator *dl); - - size_t GetSize(key_type key) const; + /// \brief Serialize a DataLocator + Status GetDataLocator(key_type, const std::shared_ptr &, + flatbuffers::Offset *) const; /// \brief Get statistics. /// \return CacheStat object CacheStat GetStat(bool GetMissingKeys = false) const; - const value_allocator &get_allocator() const; - std::string MyName() const { return subfolder_; } /// \brief Toggle locking @@ -137,12 +144,11 @@ class CachePool : public Service { void SetLocking(bool on_off) { tree_->SetLocking(on_off); } private: - value_allocator alloc_; + std::shared_ptr mp_; Path root_; const std::string subfolder_; std::shared_ptr sm_; std::shared_ptr tree_; - bool custom_arena_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc index fe2641b10f..fdec89a590 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -14,6 +14,11 @@ * limitations under the License. */ #include "minddata/dataset/engine/cache/cache_request.h" +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) +#include +#include +#include +#endif #include #include #include "minddata/dataset/core/constants.h" @@ -106,6 +111,7 @@ Status CacheRowRequest::PostReply() { } return Status::OK(); } + Status CacheRowRequest::Prepare() { if (BitTest(rq_.flag(), kDataIsInSharedMemory)) { // First one is cookie, followed by address and then size. @@ -118,10 +124,21 @@ Status CacheRowRequest::Prepare() { return Status::OK(); } -BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id, - bool local_bypass) - : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) { - rq_.set_connection_id(connection_id); +CacheRowRequest::CacheRowRequest(const CacheClient *cc) + : BaseRequest(RequestType::kCacheRow), + support_local_bypass_(cc->local_bypass_), + addr_(-1), + sz_(0), + row_id_from_server_(-1) { + rq_.set_connection_id(cc->server_connection_id_); + rq_.set_client_id(cc->client_id_); + rq_.add_buf_data(cc->cookie_); +} + +BatchFetchRequest::BatchFetchRequest(const CacheClient *cc, const std::vector &row_id) + : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(cc->local_bypass_), row_id_(row_id) { + rq_.set_connection_id(cc->server_connection_id_); + rq_.set_client_id(cc->client_id_); rq_.set_flag(support_local_bypass_ ? 
kLocalClientSupport : 0); // Convert the row id into a flatbuffer flatbuffers::FlatBufferBuilder fbb; @@ -186,9 +203,9 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, in return Status::OK(); } -CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, +CreateCacheRequest::CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz, CreateCacheRequest::CreateCacheFlag flag) - : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) { + : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag), cc_(cc) { // Type has been set already in the base constructor. So we need to fill in the connection info. // On successful return, we will get the connection id rq_.mutable_connection_info()->operator=(cinfo); @@ -209,6 +226,41 @@ Status CreateCacheRequest::Prepare() { } } +Status CreateCacheRequest::PostReply() { + auto p = flatbuffers::GetRoot(reply_.result().data()); + cc_->server_connection_id_ = p->connection_id(); + cc_->cookie_ = p->cookie()->str(); + cc_->client_id_ = p->client_id(); + // Next is a set of cpu id that we should re-adjust ourselves for better affinity. + auto sz = p->cpu_id()->size(); + cc_->cpu_list_.reserve(sz); +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) + std::string c_list; + cpu_set_t cpu_set; + CPU_ZERO(&cpu_set); +#endif + for (auto i = 0; i < sz; ++i) { + auto cpu_id = p->cpu_id()->Get(i); + cc_->cpu_list_.push_back(cpu_id); +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) + c_list += std::to_string(cpu_id) + " "; + CPU_SET(cpu_id, &cpu_set); +#endif + } + +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) + if (sz > 0) { + auto err = sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set); + if (err == -1) { + RETURN_STATUS_UNEXPECTED("Unable to set affinity. 
Errno = " + std::to_string(errno)); + } + MS_LOG(WARNING) << "Changing cpu affinity to the following list of cpu id: " + c_list; + } +#endif + + return Status::OK(); +} + Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map &map) { try { flatbuffers::FlatBufferBuilder fbb; @@ -245,6 +297,7 @@ Status GetStatRequest::PostReply() { stat_.num_disk_cached = msg->num_disk_cached(); stat_.num_mem_cached = msg->num_mem_cached(); stat_.avg_cache_sz = msg->avg_cache_sz(); + stat_.num_numa_hit = msg->num_numa_hit(); stat_.max_row_id = msg->max_row_id(); stat_.min_row_id = msg->min_row_id(); stat_.cache_service_state = msg->state(); @@ -255,14 +308,15 @@ Status ListSessionsRequest::PostReply() { auto *msg = flatbuffers::GetRoot(reply_.result().data()); auto session_vector = msg->sessions(); for (auto i = 0; i < session_vector->size(); ++i) { - SessionCacheInfo current_info; - CacheServiceStat stats; + SessionCacheInfo current_info{}; + CacheServiceStat stats{}; auto current_session_info = session_vector->Get(i); current_info.session_id = current_session_info->session_id(); current_info.connection_id = current_session_info->connection_id(); stats.num_mem_cached = current_session_info->stats()->num_mem_cached(); stats.num_disk_cached = current_session_info->stats()->num_disk_cached(); stats.avg_cache_sz = current_session_info->stats()->avg_cache_sz(); + stats.num_numa_hit = current_session_info->stats()->num_numa_hit(); stats.min_row_id = current_session_info->stats()->min_row_id(); stats.max_row_id = current_session_info->stats()->max_row_id(); stats.cache_service_state = current_session_info->stats()->state(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h index 84117f9568..43cd66f852 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -41,6 +41,7 @@ struct CacheServiceStat { int64_t num_mem_cached; int64_t num_disk_cached; int64_t avg_cache_sz; + int64_t num_numa_hit; row_id_type min_row_id; row_id_type max_row_id; int8_t cache_service_state; @@ -75,6 +76,8 @@ class BaseRequest { kHeartBeat = 14, kToggleWriteMode = 15, kListSessions = 16, + kConnectReset = 17, + kInternalFetchRow = 18, // Add new request before it. 
kRequestUnknown = 32767 }; @@ -84,10 +87,15 @@ class BaseRequest { friend class CacheClientGreeter; friend class CacheClientRequestTag; friend class CacheClient; + friend class CacheService; /// \brief Base class of a cache server request /// \param type Type of the request - explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast(type_)); } + explicit BaseRequest(RequestType type) : type_(type) { + rq_.set_type(static_cast(type_)); + rq_.set_client_id(-1); + rq_.set_flag(0); + } virtual ~BaseRequest() = default; /// \brief A print method for debugging @@ -138,15 +146,7 @@ class CacheRowRequest : public BaseRequest { public: friend class CacheServer; friend class CacheClient; - explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass) - : BaseRequest(RequestType::kCacheRow), - support_local_bypass_(local_bypass), - addr_(-1), - sz_(0), - row_id_from_server_(-1) { - rq_.set_connection_id(connection_id); - rq_.add_buf_data(cookie); - } + explicit CacheRowRequest(const CacheClient *cc); ~CacheRowRequest() override = default; /// \brief Serialize a TensorRow for streaming to the cache server @@ -193,7 +193,7 @@ class BatchFetchRequest : public BaseRequest { public: friend class CacheServer; friend class CacheService; - BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id, bool local_bypass); + BatchFetchRequest(const CacheClient *cc, const std::vector &row_id); ~BatchFetchRequest() override = default; Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr); @@ -212,21 +212,18 @@ class CreateCacheRequest : public BaseRequest { /// \param connection_id /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited /// \param flag Attributes of the cache. - explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz, + explicit CreateCacheRequest(CacheClient *cc, const CacheClientInfo &cinfo, uint64_t cache_mem_sz, CreateCacheFlag flag = CreateCacheFlag::kNone); ~CreateCacheRequest() override = default; - void ParseResult(connection_id_type *id, std::string *out) { - auto p = flatbuffers::GetRoot(reply_.result().data()); - *id = p->connection_id(); - *out = p->cookie()->str(); - } - /// Overload the base class Prepare + /// Overload the base class Prepare/PostReply Status Prepare() override; + Status PostReply() override; private: uint64_t cache_mem_sz_; CreateCacheFlag flag_; + CacheClient *cc_; }; /// \brief Request to get all the keys not present at the server. 
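
For illustration only, a minimal Linux-only sketch of the affinity adjustment that CreateCacheRequest::PostReply applies with the cpu id list returned by the server; the standalone helper name BindToCpuList is an assumption made for this sketch and is not part of the change itself:

#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <sched.h>
#include <unistd.h>
#include <cstdint>
#include <cstdio>
#include <vector>

// Build a cpu_set_t from the cpu ids suggested by the cache server and re-bind
// the whole process to them (hypothetical helper; mirrors the PostReply logic above).
int BindToCpuList(const std::vector<int32_t> &cpu_list) {
  if (cpu_list.empty()) return 0;  // nothing suggested, keep the default affinity
  cpu_set_t cpu_set;
  CPU_ZERO(&cpu_set);
  for (auto cpu_id : cpu_list) {
    CPU_SET(cpu_id, &cpu_set);  // mark each suggested cpu as allowed
  }
  if (sched_setaffinity(getpid(), sizeof(cpu_set), &cpu_set) == -1) {
    std::perror("sched_setaffinity");
    return -1;
  }
  return 0;
}
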
@@ -396,6 +393,23 @@ class ToggleWriteModeRequest : public BaseRequest { } ~ToggleWriteModeRequest() override = default; }; + +class ConnectResetRequest : public BaseRequest { + public: + friend class CacheServer; + explicit ConnectResetRequest(connection_id_type connection_id, int32_t client_id) + : BaseRequest(RequestType::kConnectReset) { + rq_.set_connection_id(connection_id); + rq_.set_client_id(client_id); + } + ~ConnectResetRequest() override = default; + + /// Override the base class function + Status Prepare() override { + CHECK_FAIL_RETURN_UNEXPECTED(rq_.client_id() != -1, "Invalid client id"); + return Status::OK(); + } +}; } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc index c7df81cf89..9bdac9f433 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/engine/cache/cache_ipc.h" #include "minddata/dataset/engine/cache/cache_service.h" @@ -43,36 +44,57 @@ Status CacheServer::DoServiceStart() { MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; } RETURN_IF_NOT_OK(vg_.ServiceStart()); - // There will be num_workers_ threads working on the grpc queue and - // the same number of threads working on the CacheServerRequest queue. + RETURN_IF_NOT_OK(hw_info_->GetNumaNodeInfo()); + auto num_numa_nodes = GetNumaNodeCount(); + // If we link with numa library. Set default memory policy. + // If we don't pin thread to cpu, then use up all memory controllers to maximize + // memory bandwidth. + RETURN_IF_NOT_OK( + CacheServerHW::SetDefaultMemoryPolicy(numa_affinity_ ? CachePoolPolicy::kLocal : CachePoolPolicy::kInterleave)); + auto my_node = hw_info_->GetMyNode(); + MS_LOG(DEBUG) << "Cache server is running on numa node " << my_node; + // Bump up num_workers_ to at least the number of numa nodes + num_workers_ = std::max(num_numa_nodes, num_workers_); + // But also it shouldn't be too many more than the hardware concurrency + auto num_cpus = hw_info_->GetCpuCount(); + num_workers_ = std::min(2 * num_cpus, num_workers_); + // Round up num_workers to a multiple of numa nodes. + auto remainder = num_workers_ % num_numa_nodes; + if (remainder > 0) num_workers_ += (num_numa_nodes - remainder); + MS_LOG(INFO) << "Re-adjusting the number of workers to " << num_workers_; + // There will be some threads working on the grpc queue and + // some number of threads working on the CacheServerRequest queue. // Like a connector object we will set up the same number of queues but // we do not need to preserve any order. We will set the capacity of - // each queue to be 128 since we are just pushing memory pointers which + // each queue to be 64 since we are just pushing memory pointers which // is only 8 byte each. - const int32_t que_capacity = 128; + const int32_t kQueCapacity = 64; // This is the request queue from the client cache_q_ = std::make_shared>(); - cache_q_->Init(num_workers_, que_capacity); + cache_q_->Init(num_workers_, kQueCapacity); + // We will match the number of grpc workers with the number of server workers. + // But technically they don't have to be the same. 
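+  // (Sizing example for the re-adjustment above, hypothetical numbers: a request for 10 workers
+  // on a 4-node, 32-cpu machine is first clamped to the range [4, 64] and then rounded up to 12,
+  // so that every numa node ends up with exactly 3 workers.)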
+  num_grpc_workers_ = num_workers_;
+  MS_LOG(DEBUG) << "Number of grpc workers is set to " << num_grpc_workers_;
   // For the grpc completion queue to work, we need to allocate some
   // tags which in our case are instances of CacheServerRequest.
   // They get recycled and we will allocate them in advance and push
   // them into some free list. We need more (two or three times) the
   // size of the cache_q. While each worker is working on a CacheServerRequest,
   // we need some extra ones being injected into the grpc completion queue.
-  const int32_t multiplier = 3;
-  const int32_t free_list_capacity = multiplier * (que_capacity + 1);
+  const int32_t kMultiplier = 2;
+  int ratio = num_workers_ / num_grpc_workers_;
+  if (num_workers_ % num_grpc_workers_) ++ratio;
+  const int32_t free_list_capacity = kMultiplier * (kQueCapacity + 1) * ratio;
   free_list_ = std::make_shared>();
-  free_list_->Init(num_workers_, free_list_capacity);
-  // We need to have a reference to the services memory pool in case
-  // the Services goes out of scope earlier than us since it is a singleton
-  mp_ = Services::GetInstance().GetServiceMemPool();
-  Allocator alloc(mp_);
-  tag_.reserve(num_workers_);
-  // Now we populate all free list.
-  for (auto m = 0; m < num_workers_; ++m) {
-    // Ideally we allocate all the free list in one malloc. But it turns out it exceeds the
-    // Arena size. So we will we will allocate one segment at a time.
-    auto my_tag = std::make_unique>>(alloc);
+  free_list_->Init(num_grpc_workers_, free_list_capacity);
+  tag_.reserve(num_grpc_workers_);
+  // Now we populate all the free lists. Round robin the free lists among the numa nodes.
+  for (auto m = 0; m < num_grpc_workers_; ++m) {
+    NumaAllocator alloc(m % num_numa_nodes, CachePoolPolicy::kPreferred);
+    // Ideally we would allocate all the free lists in one malloc. But we will allocate one segment
+    // at a time so that we can change the numa policy easily per grpc worker.
+    auto my_tag = std::make_unique>>(alloc);
     // Allocate the tag and assign it the current queue
     RETURN_IF_NOT_OK(my_tag->allocate(free_list_capacity, m));
     for (int i = 0; i < free_list_capacity; ++i) {
@@ -82,11 +104,6 @@ Status CacheServer::DoServiceStart() {
   }
   RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
   RETURN_IF_NOT_OK(free_list_->Register(&vg_));
-  // Spawn a few threads to serve the real request.
-  auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
-  for (auto i = 0; i < num_workers_; ++i) {
-    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i)));
-  }
   // Start the comm layer
   try {
     comm_layer_ = std::make_shared(port_, shared_memory_sz_in_gb_);
     RETURN_IF_NOT_OK(comm_layer_->ServiceStart());
   } catch (const std::exception &e) {
     RETURN_STATUS_UNEXPECTED(e.what());
   }
+  // Spawn a few threads to serve the real request.
+  auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
+  for (auto i = 0; i < num_workers_; ++i) {
+    Task *pTask;
+    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache service worker", std::bind(f, i), &pTask));
+    // Save a copy of the pointer to the underlying Task object. We may dynamically change its affinity if needed.
+    numa_tasks_.emplace(i, pTask);
+    // Spread out all the threads to all the numa nodes if needed
+    if (IsNumaAffinityOn()) {
+      auto numa_id = i % num_numa_nodes;
+      RETURN_IF_NOT_OK(SetAffinity(*pTask, numa_id));
+    }
+  }
   // Finally loop forever to handle the request.
auto r = std::bind(&CacheServer::RpcRequest, this, std::placeholders::_1); - for (auto i = 0; i < num_workers_; ++i) { - RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i))); + for (auto i = 0; i < num_grpc_workers_; ++i) { + Task *pTask; + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("rpc worker", std::bind(r, i), &pTask)); + // All these grpc workers will be allocated to the same node which is where we allocate all those free tag + // memory. + if (IsNumaAffinityOn()) { + RETURN_IF_NOT_OK(SetAffinity(*pTask, i % num_numa_nodes)); + } } return Status::OK(); } @@ -108,8 +144,6 @@ Status CacheServer::DoServiceStop() { // First stop all the threads. RETURN_IF_NOT_OK(vg_.ServiceStop()); // Clean up all the caches if any. - // Dump a message how much memory we have consumed in total. - MS_LOG(INFO) << "Memory usage for the current server: " << GetMemoryUsage() << " bytes."; UniqueLock lck(&rwLock_); auto it = all_caches_.begin(); while (it != all_caches_.end()) { @@ -134,13 +168,14 @@ CacheService *CacheServer::GetService(connection_id_type id) const { Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing connection info"); std::string cookie; + int32_t client_id; auto session_id = rq->connection_info().session_id(); auto crc = rq->connection_info().crc(); // Before allowing the creation, make sure the session had already been created by the user // Our intention is to add this cache to the active sessions list so leave the list locked during // this entire function. - UniqueLock lock(&sessions_lock_); + UniqueLock sess_lck(&sessions_lock_); auto session_it = active_sessions_.find(session_id); if (session_it == active_sessions_.end()) { RETURN_STATUS_UNEXPECTED("A cache creation has been requested but the session was not found!"); @@ -163,6 +198,7 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { } flatbuffers::FlatBufferBuilder fbb; flatbuffers::Offset off_cookie; + flatbuffers::Offset> off_cpu_list; // Before creating the cache, first check if this is a request for a shared usage of an existing cache // If two CreateService come in with identical connection_id, we need to serialize the create. // The first create will be successful and be given a special cookie. @@ -171,32 +207,74 @@ Status CacheServer::CreateService(CacheRequest *rq, CacheReply *reply) { if (global_shutdown_) { return Status::OK(); } + // We would like to protect ourselves from over allocating too much. We will go over existing cache + // and calculate how much we have consumed so far. auto end = all_caches_.end(); - auto it = all_caches_.find(connection_id); + auto it = all_caches_.begin(); bool duplicate = false; + auto avail_mem = CacheServerHW::GetTotalSystemMemory() * memory_cap_ratio_; + int64_t max_avail = avail_mem; + while (it != end) { + if (it->first == connection_id) { + duplicate = true; + break; + } else { + auto &cs = it->second; + CacheService::ServiceStat stat; + RETURN_IF_NOT_OK(cs->GetStat(&stat)); + int64_t mem_consumed = stat.stat_.num_mem_cached * stat.stat_.average_cache_sz; + max_avail -= mem_consumed; + if (max_avail <= 0) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); + } + } + ++it; + } if (it == end) { + // If we have some cache using some memory already, make a reasonable decision if we should return + // out of memory. + if (max_avail < avail_mem) { + int64_t req_mem = cache_mem_sz * 1048576L; // It is in MB unit. 
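+      // (Hypothetical example: with avail_mem around 50 GB and existing caches already consuming
+      // 40 GB, max_avail is roughly 10 GB, so a new cache asking for 16384 MB would be rejected
+      // by the check below rather than risking the whole server.)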
+ if (req_mem > max_avail) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); + } else if (req_mem == 0) { + // This cache request is specifying unlimited memory up to the memory cap. If we have consumed more than + // 85% of our limit, fail this request. + if (static_cast(max_avail) / static_cast(avail_mem) <= 0.15) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, "Please destroy some sessions"); + } + } + } std::unique_ptr cs; try { cs = std::make_unique(cache_mem_sz, spill ? top_ : "", generate_id); RETURN_IF_NOT_OK(cs->ServiceStart()); cookie = cs->cookie(); + client_id = cs->num_clients_.fetch_add(1); all_caches_.emplace(connection_id, std::move(cs)); } catch (const std::bad_alloc &e) { return Status(StatusCode::kOutOfMemory); } - // Add the cache into the active session tracking. - // We have already validated that the session exists and that this is a new cache created. - session_it->second.insert(connection_id); - } else { duplicate = true; + client_id = it->second->num_clients_.fetch_add(1); MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; } - + // Shuffle the worker threads. But we need to release the locks or we will deadlock when calling + // the following function + lck.Unlock(); + sess_lck.Unlock(); + auto numa_id = client_id % GetNumaNodeCount(); + std::vector cpu_list = hw_info_->GetCpuList(numa_id); + // Send back the data off_cookie = fbb.CreateString(cookie); + off_cpu_list = fbb.CreateVector(cpu_list); CreateCacheReplyMsgBuilder bld(fbb); bld.add_connection_id(connection_id); bld.add_cookie(off_cookie); + bld.add_client_id(client_id); + // The last thing we send back is a set of cpu id that we suggest the client should bind itself to + bld.add_cpu_id(off_cpu_list); auto off = bld.Finish(); fbb.Finish(off); reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); @@ -220,26 +298,8 @@ Status CacheServer::DestroyCache(CacheRequest *rq) { MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; } } - - // Now that this cache is removed, we need to also remove it's connection id from active session tracking - auto session_id = GetSessionID(id); - UniqueLock sess_lck(&sessions_lock_); - - auto it = active_sessions_.find(session_id); - if (it == active_sessions_.end()) { - // The session was not found in the active sessions - RETURN_STATUS_UNEXPECTED("A destroy cache request has been completed but it had a stale session id!"); - } - - auto connection_it = it->second.find(id); - if (connection_it == it->second.end()) { - RETURN_STATUS_UNEXPECTED("A destroy cache request could not find the connection in the activate sessions!"); - } - - // remove that connection id from the set - it->second.erase(connection_it); - MS_LOG(INFO) << "Destroyed cache " << id << " and removed from active session " << session_id; - + // We aren't touching the session list even though we may be dropping the last remaining cache of a session. + // Leave that to be done by the drop session command. return Status::OK(); } @@ -266,6 +326,7 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) { buffers.push_back(rq->buf_data(i).data()); } row_id_type id = -1; + // We will allocate the memory the same numa node this thread is bound to. 
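+    // (With numa affinity on, this worker thread is pinned to one node and the numa pool
+    // presumably prefers that node, which is what the num_numa_hit statistic reports.)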
RETURN_IF_NOT_OK(cs->CacheRow(buffers, &id)); reply->set_result(std::to_string(id)); } else { @@ -301,6 +362,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { if (!cs->HasBuildPhase() || cookie == cs->cookie()) { row_id_type id = -1; ReadableSlice src(p, sz); + // We will allocate the memory the same numa node this thread is bound to. rc = cs->FastCacheRow(src, &id); reply->set_result(std::to_string(id)); } else { @@ -330,9 +392,19 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { for (auto i = 0; i < sz; ++i) { row_id.push_back(p->row_id()->Get(i)); } - int64_t mem_sz = 0; - std::vector v; - RETURN_IF_NOT_OK(cs->PreBatchFetch(row_id, &v, &mem_sz)); + std::shared_ptr fbb = std::make_shared(); + RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb)); + auto locator = flatbuffers::GetRoot(fbb->GetBufferPointer()); + int64_t mem_sz = sizeof(int64_t) * (sz + 1); + for (auto i = 0; i < sz; ++i) { + auto row_sz = locator->rows()->Get(i)->size(); + // row_sz is the size of the cached data. Later we will spawn multiple threads + // each of which will copy the data into either shared memory or protobuf concurrently but + // to different region. + // To avoid false sharing, we will bump up row_sz to be a multiple of 4k, i.e. 4096 bytes + row_sz = round_up_4K(row_sz); + mem_sz += row_sz; + } auto client_flag = rq->flag(); bool local_client = BitTest(client_flag, kLocalClientSupport); // For large amount data to be sent back, we will use shared memory provided it is a local @@ -346,7 +418,11 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { void *q = nullptr; RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); WritableSlice dest(q, mem_sz); - RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); + Status rc = cs->BatchFetch(fbb, &dest); + if (rc.IsError()) { + shared_pool->Deallocate(q); + return rc; + } // We can't return the absolute address which makes no sense to the client. // Instead we return the difference. 
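+      // (The client maps the same shared memory segment at its own virtual address, so only the
+      // offset from the segment base is meaningful across the two processes.)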
auto difference = reinterpret_cast(q) - reinterpret_cast(base); @@ -363,7 +439,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { return Status(StatusCode::kOutOfMemory); } WritableSlice dest(mem.data(), mem_sz); - RETURN_IF_NOT_OK(cs->BatchFetch(row_id, v, &dest)); + RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest)); reply->set_result(std::move(mem)); } } @@ -386,6 +462,7 @@ Status CacheServer::GetStat(CacheRequest *rq, CacheReply *reply) { bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); bld.add_avg_cache_sz(svc_stat.stat_.average_cache_sz); + bld.add_num_numa_hit(svc_stat.stat_.num_numa_hit); bld.add_max_row_id(svc_stat.stat_.max_key); bld.add_min_row_id(svc_stat.stat_.min_key); bld.add_state(svc_stat.state_); @@ -506,30 +583,27 @@ Status CacheServer::ToggleWriteMode(CacheRequest *rq) { } Status CacheServer::ListSessions(CacheReply *reply) { - SharedLock lck(&sessions_lock_); - + SharedLock sess_lck(&sessions_lock_); + SharedLock lck(&rwLock_); flatbuffers::FlatBufferBuilder fbb; std::vector> session_msgs_vector; - for (auto it = active_sessions_.begin(); it != active_sessions_.end(); it++) { - session_id_type current_session_id = it->first; - // Loop over each cache inside this session - if (!it->second.empty()) { - for (auto current_conn_id : it->second) { - CacheService *cs = GetService(current_conn_id); - if (cs == nullptr) { - std::string errMsg = "Connection " + std::to_string(current_conn_id) + " not found during list sessions."; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - CacheService::ServiceStat svc_stat; - RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); - auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, - svc_stat.stat_.average_cache_sz, svc_stat.stat_.min_key, - svc_stat.stat_.max_key, svc_stat.state_); - auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); - session_msgs_vector.push_back(current_session_info); - } + for (auto const ¤t_session_id : active_sessions_) { + bool found = false; + for (auto const &it : all_caches_) { + auto current_conn_id = it.first; + if (GetSessionID(current_conn_id) == current_session_id) { + found = true; + auto &cs = it.second; + CacheService::ServiceStat svc_stat; + RETURN_IF_NOT_OK(cs->GetStat(&svc_stat)); + auto current_stats = CreateServiceStatMsg(fbb, svc_stat.stat_.num_mem_cached, svc_stat.stat_.num_disk_cached, + svc_stat.stat_.average_cache_sz, svc_stat.stat_.num_numa_hit, + svc_stat.stat_.min_key, svc_stat.stat_.max_key, svc_stat.state_); + auto current_session_info = CreateListSessionMsg(fbb, current_session_id, current_conn_id, current_stats); + session_msgs_vector.push_back(current_session_info); } - } else { + } + if (!found) { // If there is no cache created yet, assign a connection id of 0 along with empty stats auto current_stats = CreateServiceStatMsg(fbb, 0, 0, 0, 0, 0, 0); auto current_session_info = CreateListSessionMsg(fbb, current_session_id, 0, current_stats); @@ -542,18 +616,35 @@ Status CacheServer::ListSessions(CacheReply *reply) { auto offset = s_builder.Finish(); fbb.Finish(offset); reply->set_result(fbb.GetBufferPointer(), fbb.GetSize()); + return Status::OK(); +} +Status CacheServer::ConnectReset(CacheRequest *rq) { + auto connection_id = rq->connection_id(); + // Hold the shared lock to prevent the cache from being dropped. 
+ SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); + if (cs == nullptr) { + std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto client_id = rq->client_id(); + MS_LOG(WARNING) << "Client id " << client_id << " with connection id " << connection_id << " disconnects"; + cs->num_clients_--; + } return Status::OK(); } /// \brief This is the main loop the cache server thread(s) are running. /// Each thread will pop a request and send the result back to the client using grpc /// \return -Status CacheServer::ServerRequest(int32_t worker_id) { +Status CacheServer::ServerRequest(worker_id_t worker_id) { TaskManager::FindMe()->Post(); + MS_LOG(DEBUG) << "Worker id " << worker_id << " is running on node " << hw_info_->GetMyNode(); auto &my_que = cache_q_->operator[](worker_id); // Loop forever until we are interrupted or shutdown. while (!global_shutdown_) { + bool internal_request = false; CacheServerRequest *cache_req = nullptr; RETURN_IF_NOT_OK(my_que->PopFront(&cache_req)); auto &rq = cache_req->rq_; @@ -571,8 +662,17 @@ Status CacheServer::ServerRequest(int32_t worker_id) { } break; } - case BaseRequest::RequestType::kBatchFetchRows: { - cache_req->rc_ = BatchFetchRows(&rq, &reply); + case BaseRequest::RequestType::kInternalFetchRow: { + internal_request = true; + auto connection_id = rq.connection_id(); + SharedLock lck(&rwLock_); + CacheService *cs = GetService(connection_id); + if (cs == nullptr) { + std::string errMsg = "Connection " + std::to_string(connection_id) + " not found"; + cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot(rq.buf_data(0).data())); + } break; } case BaseRequest::RequestType::kCreateCache: { @@ -636,6 +736,10 @@ Status CacheServer::ServerRequest(int32_t worker_id) { cache_req->rc_ = ListSessions(&reply); break; } + case BaseRequest::RequestType::kConnectReset: { + cache_req->rc_ = ConnectReset(&rq); + break; + } default: std::string errMsg("Unknown request type : "); errMsg += std::to_string(static_cast(cache_req->type_)); @@ -647,7 +751,13 @@ Status CacheServer::ServerRequest(int32_t worker_id) { // We will re-tag the request back to the grpc queue. Once it comes back from the client, // the CacheServerRequest, i.e. the pointer cache_req, will be free if (!global_shutdown_) { - cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); + if (!internal_request) { + cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req); + } else { + // This is an internal request and is not tied to rpc. But need to post because there + // is a thread waiting on the completion of this request. + cache_req->wp_.Set(); + } } } return Status::OK(); @@ -667,12 +777,20 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int int32_t shared_meory_sz_in_gb, float memory_cap_ratio) : top_(spill_path), num_workers_(num_workers), + num_grpc_workers_(num_workers_), port_(port), shared_memory_sz_in_gb_(shared_meory_sz_in_gb), global_shutdown_(false), memory_cap_ratio_(memory_cap_ratio), - cur_mem_usage_(0) { - memory_cap_ = CacheServer::GetTotalSystemMemory() * memory_cap_ratio_; + numa_affinity_(true) { + hw_info_ = std::make_shared(); + // If we are not linked with numa library (i.e. NUMA_ENABLED is false), turn off cpu + // affinity which can make performance worse. 
+ if (!CacheServerHW::numa_enabled()) { + numa_affinity_ = false; + MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build " + "that is compiled with numa support for more optimal performance"; + } } Status CacheServer::Run(int msg_qid) { @@ -719,51 +837,52 @@ Status CacheServer::ReturnRequestTag(CacheServerRequest *p) { Status CacheServer::DestroySession(CacheRequest *rq) { CHECK_FAIL_RETURN_UNEXPECTED(rq->has_connection_info(), "Missing session id"); auto drop_session_id = rq->connection_info().session_id(); - - UniqueLock lck(&sessions_lock_); - - // First validate that this session exists - auto it = active_sessions_.find(drop_session_id); - if (it == active_sessions_.end()) { - RETURN_STATUS_UNEXPECTED("A destroy session command has been requested but the session was not found!"); - } - + // Grab the locks in the correct order to avoid deadlock. + UniqueLock sess_lck(&sessions_lock_); + UniqueLock lck(&rwLock_); // Iterate over the set of connection id's for this session that we're dropping and erase each one. - { - UniqueLock rwlck(&rwLock_); - for (auto drop_connection_id : it->second) { - auto cache_drop_it = all_caches_.find(drop_connection_id); - if (cache_drop_it == all_caches_.end()) { - RETURN_STATUS_UNEXPECTED("active session tracking had stale or incorrect cache entry."); - } - all_caches_.erase(cache_drop_it); - MS_LOG(INFO) << "Session destroy: Destroy cache with id " << drop_connection_id; - // **Do not bother to remove the cache connection id from the active session because we will soon remove the - // entire session. + bool found = false; + for (auto it = all_caches_.begin(); it != all_caches_.end();) { + auto connection_id = it->first; + auto session_id = GetSessionID(connection_id); + // We can just call DestroyCache() but we are holding a lock already. Doing so will cause deadlock. + // So we will just manually do it. + if (session_id == drop_session_id) { + found = true; + it = all_caches_.erase(it); + MS_LOG(INFO) << "Destroy cache with id " << connection_id; + } else { + ++it; } } - // Finally remove the session itself - active_sessions_.erase(it); - MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; - - return Status::OK(); + auto n = active_sessions_.erase(drop_session_id); + if (n > 0) { + MS_LOG(INFO) << "Session destroyed with id " << drop_session_id; + return Status::OK(); + } else { + if (found) { + std::string errMsg = + "A destroy cache request has been completed but it had a stale session id " + std::to_string(drop_session_id); + RETURN_STATUS_UNEXPECTED(errMsg); + } else { + std::string errMsg = "Session id " + std::to_string(drop_session_id) + " not found."; + return Status(StatusCode::kFileNotExist, errMsg); + } + } } session_id_type CacheServer::GenerateSessionID() { - UniqueLock lock(&sessions_lock_); + UniqueLock sess_lck(&sessions_lock_); auto mt = GetRandomDevice(); std::uniform_int_distribution distribution(0, std::numeric_limits::max()); session_id_type session_id; bool duplicate = false; do { session_id = distribution(mt); - auto it = active_sessions_.find(session_id); - duplicate = (it != active_sessions_.end()); + auto r = active_sessions_.insert(session_id); + duplicate = !r.second; } while (duplicate); - - // Add this session to our tracking of active sessions with initialized empty set of connections. 
- active_sessions_[session_id] = std::set(); return session_id; } @@ -789,7 +908,7 @@ Status CacheServer::FreeSharedMemory(CacheRequest *rq) { return Status::OK(); } -Status CacheServer::RpcRequest(int32_t worker_id) { +Status CacheServer::RpcRequest(worker_id_t worker_id) { TaskManager::FindMe()->Post(); RETURN_IF_NOT_OK(comm_layer_->HandleRequest(worker_id)); return Status::OK(); @@ -820,12 +939,22 @@ Status CacheServer::GlobalShutdown() { return Status::OK(); } -int64_t CacheServer::GetTotalSystemMemory() { - auto pages = sysconf(_SC_PHYS_PAGES); - auto page_size = sysconf(_SC_PAGE_SIZE); - auto total = static_cast(pages) * static_cast(page_size); - MS_LOG(INFO) << "Total physical RAM in bytes: " << total; - return total; +worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) { + auto num_numa_nodes = GetNumaNodeCount(); + MS_ASSERT(numa_id < num_numa_nodes); + auto num_workers_per_node = GetNumWorkers() / num_numa_nodes; + std::mt19937 gen = GetRandomDevice(); + std::uniform_int_distribution dist(0, num_workers_per_node - 1); + auto n = dist(gen); + worker_id_t worker_id = n * num_numa_nodes + numa_id; + MS_ASSERT(worker_id < GetNumWorkers()); + return worker_id; +} + +worker_id_t CacheServer::GetRandomWorker() { + std::mt19937 gen = GetRandomDevice(); + std::uniform_int_distribution dist(0, num_workers_ - 1); + return dist(gen); } Status CacheServer::Builder::IpcResourceCleanup() { @@ -842,6 +971,8 @@ Status CacheServer::Builder::IpcResourceCleanup() { rc = mem.Attach(); if (rc.IsError()) { return Status::OK(); + } else { + RETURN_IF_NOT_OK(mem.Detach()); } int32_t num_attached; RETURN_IF_NOT_OK(mem.GetNumAttached(&num_attached)); @@ -892,5 +1023,16 @@ Status CacheServer::Builder::SanityCheck() { RETURN_IF_NOT_OK(IpcResourceCleanup()); return Status::OK(); } + +CacheServer::Builder::Builder() + : top_("/tmp"), + num_workers_(std::thread::hardware_concurrency() / 2), + port_(50052), + shared_memory_sz_in_gb_(4), + memory_cap_ratio_(0.8) { + if (num_workers_ == 0) { + num_workers_ = 1; + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h index 631a4abb08..85407d44e1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -17,23 +17,31 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVER_H_ +#include #include #include #include #include +#include +#include #include #include #include #include #include #include +#include +#include "minddata/dataset/engine/cache/cache_hw.h" +#include "minddata/dataset/engine/cache/cache_numa.h" #include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/cache/cache_grpc_server.h" +#include "minddata/dataset/engine/cache/cache_pool.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/util/allocator.h" #include "minddata/dataset/util/arena.h" -#include "minddata/dataset/util/cache_pool.h" #include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/semaphore.h" #include "minddata/dataset/util/service.h" #include "minddata/dataset/util/services.h" #include "minddata/dataset/util/system_pool.h" @@ -47,9 +55,10 @@ class CacheServer : public Service { public: friend class Services; using cache_index = std::map>; + class Builder { public: - Builder() : 
top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4), memory_cap_ratio_(0.8) {} + Builder(); ~Builder() = default; @@ -161,26 +170,40 @@ class CacheServer : public Service { /// \return Status object static Status ReturnRequestTag(CacheServerRequest *p); - /// \brief This returns the size (in bytes) of the physical RAM on the machine. - /// \return the size (in bytes) of the physical RAM on the machine. - static int64_t GetTotalSystemMemory(); + /// Return an instance of the numa control + std::shared_ptr GetHWControl() { return hw_info_; } - /// \brief Internally this is how much we will try to use without exceeding the limit - /// \return Internal cap maximum - int64_t GetAvailableSystemMemory() { return memory_cap_; } + /// \brief Set CPU affinity + Status SetAffinity(const Task &tk, numa_id_t numa_node) { return hw_info_->SetAffinity(tk, numa_node); } - /// \brief Find out the current memory usage - int64_t GetMemoryUsage() { return cur_mem_usage_; } + /// \brief return number of workers + auto GetNumWorkers() const { return num_workers_; } - /// \brief This updates our current memory usage. - enum MemUsageOp : int8_t { kAllocate = 1, kFree = 2 }; - void UpdateMemoryUsage(int64_t sz, MemUsageOp op) { - if (op == MemUsageOp::kAllocate) { - cur_mem_usage_ += sz; - } else { - cur_mem_usage_ -= sz; - } - } + /// \brief return number of grpc workers + auto GetNumGrpcWorkers() const { return num_grpc_workers_; } + + /// \brief return number of numa nodes + auto GetNumaNodeCount() const { return hw_info_->GetNumaNodeCount(); } + + /// \brief Assign a worker by a numa id + /// \return worker id + worker_id_t GetWorkerByNumaId(numa_id_t node_id); + + /// \brief Randomly pick a worker + /// \return worker id + worker_id_t GetRandomWorker(); + + /// \brief Check if we bind threads to numa cores + bool IsNumaAffinityOn() const { return numa_affinity_; } + + /// \brief Internal function to do row batch fetch + /// \param rq Request + /// \param reply Reply + /// \return Status object + Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); + + /// \brief Return the memory cap ratio + float GetMemoryCapRatio() const { return memory_cap_ratio_; } private: static std::once_flag init_instance_flag_; @@ -189,20 +212,21 @@ class CacheServer : public Service { mutable RWLock sessions_lock_; std::string top_; cache_index all_caches_; - std::map> active_sessions_; + std::set active_sessions_; std::shared_ptr> cache_q_; std::shared_ptr> free_list_; - std::vector>>> tag_; + std::vector>>> tag_; std::shared_ptr comm_layer_; - std::shared_ptr mp_; TaskGroup vg_; int32_t num_workers_; + int32_t num_grpc_workers_; int32_t port_; int32_t shared_memory_sz_in_gb_; std::atomic global_shutdown_; float memory_cap_ratio_; - int64_t memory_cap_; - std::atomic cur_mem_usage_; + std::shared_ptr hw_info_; + std::map numa_tasks_; + bool numa_affinity_; /// \brief Constructor /// \param spill_path Top directory for spilling buffers to. @@ -226,11 +250,11 @@ class CacheServer : public Service { Status DestroyCache(CacheRequest *rq); /// \brief Entry point for all internal server threads. - Status ServerRequest(int32_t worker_id); + Status ServerRequest(worker_id_t worker_id); /// \brief Entry point for all grpc threads. 
/// \return - Status RpcRequest(int32_t worker_id); + Status RpcRequest(worker_id_t worker_id); Status DestroySession(CacheRequest *rq); @@ -266,12 +290,6 @@ class CacheServer : public Service { Status FastCacheRow(CacheRequest *rq, CacheReply *reply); Status CacheRow(CacheRequest *rq, CacheReply *reply); - /// \brief Internal function to do row batch fetch - /// \param rq Request - /// \param reply Reply - /// \return Status object - Status BatchFetchRows(CacheRequest *rq, CacheReply *reply); - /// \brief Internal function to get statistics /// \param rq /// \param reply @@ -309,6 +327,9 @@ class CacheServer : public Service { /// \param reply /// \return Status object Status ListSessions(CacheReply *reply); + + /// \brief Connect request by a pipeline + Status ConnectReset(CacheRequest *rq); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc index 727f9e736f..624a279ba3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc @@ -13,51 +13,45 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/engine/cache/cache_numa.h" +#include "minddata/dataset/util/random.h" #include "minddata/dataset/util/slice.h" namespace mindspore { namespace dataset { CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) : root_(root), - cache_mem_sz_(mem_sz), + cache_mem_sz_(mem_sz * 1048576L), // mem_sz is in MB unit cp_(nullptr), next_id_(0), generate_id_(generate_id), - st_(generate_id ? State::kBuildPhase : State::kNone), - cur_mem_usage_(0), - cur_disk_usage_(0) {} + num_clients_(0), + st_(generate_id ? CacheServiceState::kBuildPhase : CacheServiceState::kNone) {} CacheService::~CacheService() { (void)ServiceStop(); } -bool CacheService::UseArena() { - // If fixed size, use Arena instead of the pool from global context. - return (cache_mem_sz_ > 0); -} - Status CacheService::DoServiceStart() { - std::shared_ptr mp_; CacheServer &cs = CacheServer::GetInstance(); - if (UseArena()) { - auto avail_mem = cs.GetAvailableSystemMemory() / 1048576L; + float memory_cap_ratio = cs.GetMemoryCapRatio(); + if (cache_mem_sz_ > 0) { + auto avail_mem = CacheServerHW::GetTotalSystemMemory(); if (cache_mem_sz_ > avail_mem) { // Output a warning that we use more than recommended. If we fail to allocate, we will fail anyway. - MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " MB while available system memory " << avail_mem - << " MB"; + MS_LOG(WARNING) << "Requesting cache size " << cache_mem_sz_ << " while available system memory " << avail_mem; } - // Create a fixed size arena based on the parameter. - std::shared_ptr arena; - RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); - mp_ = std::move(arena); - // update the global usage only. - cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kAllocate); - } else { - // Unlimited size. Simply use a system pool. Another choice is CircularPool. 
- mp_ = std::make_shared(); + memory_cap_ratio = static_cast(cache_mem_sz_) / avail_mem; + } + numa_pool_ = std::make_shared(cs.GetHWControl(), memory_cap_ratio); + // It is possible we aren't able to allocate the pool for many reasons. + std::vector avail_nodes = numa_pool_->GetAvailableNodes(); + if (avail_nodes.empty()) { + RETURN_STATUS_UNEXPECTED("Unable to bring up numa memory pool"); } - // Put together a CachePool for backing up the Tensor - cp_ = std::make_shared(CachePool::value_allocator(mp_), UseArena(), root_); + // Put together a CachePool for backing up the Tensor. + cp_ = std::make_shared(numa_pool_, root_); RETURN_IF_NOT_OK(cp_->ServiceStart()); // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. cookie_ = cp_->MyName(); @@ -68,26 +62,18 @@ Status CacheService::DoServiceStop() { if (cp_ != nullptr) { RETURN_IF_NOT_OK(cp_->ServiceStop()); } - CacheServer &cs = CacheServer::GetInstance(); - if (UseArena()) { - cs.UpdateMemoryUsage(cache_mem_sz_ * 1048576L, CacheServer::MemUsageOp::kFree); - } else { - MS_LOG(INFO) << "Memory/disk usage for the current service: " << GetMemoryUsage() << " bytes and " << GetDiskUsage() - << " bytes."; - cs.UpdateMemoryUsage(GetMemoryUsage(), CacheServer::MemUsageOp::kFree); - } return Status::OK(); } Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated) { SharedLock rw(&rw_lock_); RETURN_UNEXPECTED_IF_NULL(row_id_generated); - if (st_ == State::kFetchPhase) { + if (st_ == CacheServiceState::kFetchPhase) { // For this kind of cache service, once we are done with the build phase into fetch phase, we can't // allow other to cache more rows. RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); } - if (st_ == State::kNoLocking) { + if (st_ == CacheServiceState::kNoLocking) { // We ignore write this request once we turn off locking on the B+ tree. So we will just // return out of memory from now on. return Status(StatusCode::kOutOfMemory); @@ -128,26 +114,13 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); total_sz += msg->data_sz()->Get(i); } - // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. - // Otherwise, we check how much (globally) how much we use and may simply spill to disk - // directly. - CacheServer &cs = CacheServer::GetInstance(); - bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); - Status rc = cp_->Insert(*row_id_generated, all_data, write_to_disk_directly); + // Now we cache the buffer. 
+ Status rc = cp_->Insert(*row_id_generated, all_data); if (rc == Status(StatusCode::kDuplicateKey)) { MS_LOG(DEBUG) << "Ignoring duplicate key."; } else { RETURN_IF_NOT_OK(rc); } - // All good, then update the memory usage local and global (if not using arena) - if (write_to_disk_directly) { - cur_disk_usage_ += total_sz; - } else { - cur_mem_usage_ += total_sz; - if (!UseArena()) { - cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); - } - } return Status::OK(); } catch (const std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); @@ -157,12 +130,12 @@ Status CacheService::CacheRow(const std::vector &buf, row_id_type Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) { SharedLock rw(&rw_lock_); RETURN_UNEXPECTED_IF_NULL(row_id_generated); - if (st_ == State::kFetchPhase) { + if (st_ == CacheServiceState::kFetchPhase) { // For this kind of cache service, once we are done with the build phase into fetch phase, we can't // allow other to cache more rows. RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); } - if (st_ == State::kNoLocking) { + if (st_ == CacheServiceState::kNoLocking) { // We ignore write this request once we turn off locking on the B+ tree. So we will just // return out of memory from now on. return Status(StatusCode::kOutOfMemory); @@ -183,27 +156,13 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_ } *row_id_generated = msg->row_id(); } - // Now we cache the buffer. If we are using Arena which has a fixed cap, then just do it. - // Otherwise, we check how much (globally) how much we use and may simply spill to disk - // directly. - auto total_sz = src.GetSize(); - CacheServer &cs = CacheServer::GetInstance(); - bool write_to_disk_directly = UseArena() ? false : (total_sz + cs.GetMemoryUsage()) > cs.GetAvailableSystemMemory(); - Status rc = cp_->Insert(*row_id_generated, {src}, write_to_disk_directly); + // Now we cache the buffer. 
+ Status rc = cp_->Insert(*row_id_generated, {src}); if (rc == Status(StatusCode::kDuplicateKey)) { MS_LOG(DEBUG) << "Ignoring duplicate key."; } else { RETURN_IF_NOT_OK(rc); } - // All good, then update the memory usage local and global (if not using arena) - if (write_to_disk_directly) { - cur_disk_usage_ += total_sz; - } else { - cur_mem_usage_ += total_sz; - if (!UseArena()) { - cs.UpdateMemoryUsage(total_sz, CacheServer::MemUsageOp::kAllocate); - } - } return Status::OK(); } catch (const std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); @@ -247,52 +206,116 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { return Status::OK(); } -Status CacheService::PreBatchFetch(const std::vector &v, std::vector *out, - int64_t *mem_sz) { +Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector &v, + const std::shared_ptr &fbb) { SharedLock rw(&rw_lock_); - RETURN_UNEXPECTED_IF_NULL(out); - RETURN_UNEXPECTED_IF_NULL(mem_sz); - const auto num_elements = v.size(); - *mem_sz = (num_elements + 1) * sizeof(int64_t); - (*out).reserve(num_elements); + std::vector> datalocator_v; + datalocator_v.reserve(v.size()); for (auto row_id : v) { - auto sz = cp_->GetSize(row_id); - if (sz > 0) { - (*out).emplace_back(row_id, sz); - (*mem_sz) += sz; - } else { - // key not found - (*out).emplace_back(-1, 0); - } + flatbuffers::Offset offset; + RETURN_IF_NOT_OK(cp_->GetDataLocator(row_id, fbb, &offset)); + datalocator_v.push_back(offset); } + auto offset_v = fbb->CreateVector(datalocator_v); + BatchDataLocatorMsgBuilder bld(*fbb); + bld.add_connection_id(connection_id); + bld.add_rows(offset_v); + auto offset_final = bld.Finish(); + fbb->Finish(offset_final); return Status::OK(); } -Status CacheService::BatchFetch(const std::vector &v, const std::vector &info, - WritableSlice *out) const { +Status CacheService::BatchFetch(const std::shared_ptr &fbb, WritableSlice *out) const { RETURN_UNEXPECTED_IF_NULL(out); SharedLock rw(&rw_lock_); - if (st_ == State::kBuildPhase) { + if (st_ == CacheServiceState::kBuildPhase) { // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); } - const auto num_elements = v.size(); + CacheServer &cs = CacheServer::GetInstance(); + int32_t numQ = cs.GetNumGrpcWorkers(); + auto rng = GetRandomDevice(); + std::uniform_int_distribution distribution(0, numQ - 1); + int32_t qID = distribution(rng); + std::vector cache_rq_list; + auto p = flatbuffers::GetRoot(fbb->GetBufferPointer()); + const auto num_elements = p->rows()->size(); + auto connection_id = p->connection_id(); + cache_rq_list.reserve(num_elements); int64_t data_offset = (num_elements + 1) * sizeof(int64_t); auto *offset_array = reinterpret_cast(out->GetMutablePointer()); offset_array[0] = data_offset; for (auto i = 0; i < num_elements; ++i) { - auto sz = info.at(i).second; - offset_array[i + 1] = offset_array[i] + sz; + auto data_locator = p->rows()->Get(i); + auto node_id = data_locator->node_id(); + size_t sz = data_locator->size(); + void *source_addr = reinterpret_cast(data_locator->addr()); + auto key = data_locator->key(); + // Please read the comment in CacheServer::BatchFetchRows where we allocate + // the buffer big enough so each thread (which we are going to dispatch) will + // not run into false sharing problem. We are going to round up sz to 4k. 
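+    // (Assuming the usual implementation, round_up_4K(sz) returns (sz + 4095) & ~4095,
+    // i.e. sz bumped up to the next multiple of 4096 bytes.)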
+    auto sz_4k = round_up_4K(sz);
+    offset_array[i + 1] = offset_array[i] + sz_4k;
     if (sz > 0) {
       WritableSlice row_data(*out, offset_array[i], sz);
-      auto key = info.at(i).first;
-      size_t bytesRead = 0;
-      RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
-      if (bytesRead != sz) {
-        MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "."
-                      << " Internal key: " << key << "\n";
-        RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
-      }
+      // Get a request and send it to the proper worker (on some numa node) to do the fetch.
+      worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker();
+      CacheServerRequest *cache_rq;
+      RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq));
+      cache_rq_list.push_back(cache_rq);
+      // Set up all the necessary fields.
+      cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
+      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
+      cache_rq->rq_.set_connection_id(connection_id);
+      cache_rq->rq_.set_type(static_cast(cache_rq->type_));
+      auto dest_addr = row_data.GetMutablePointer();
+      flatbuffers::FlatBufferBuilder fb2;
+      FetchRowMsgBuilder bld(fb2);
+      bld.add_key(key);
+      bld.add_size(sz);
+      bld.add_source_addr(reinterpret_cast(source_addr));
+      bld.add_dest_addr(reinterpret_cast(dest_addr));
+      auto offset = bld.Finish();
+      fb2.Finish(offset);
+      cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
+      RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq));
+    }
+  }
+  // Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding
+  // any lock while we may be waiting for a long time.
+  rw.Unlock();
+  Status rc;
+  for (CacheServerRequest *rq : cache_rq_list) {
+    RETURN_IF_NOT_OK(rq->Wait());
+    if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
+      rc = rq->rc_;
+    }
+    RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq));
+  }
+  return rc;
+}
+
+Status CacheService::InternalFetchRow(const FetchRowMsg *p) {
+  RETURN_UNEXPECTED_IF_NULL(p);
+  SharedLock rw(&rw_lock_);
+  size_t bytesRead = 0;
+  int64_t key = p->key();
+  size_t sz = p->size();
+  void *source_addr = reinterpret_cast(p->source_addr());
+  void *dest_addr = reinterpret_cast(p->dest_addr());
+  WritableSlice dest(dest_addr, sz);
+  if (source_addr != nullptr) {
+    // We are not checking if the row is still present but simply use the information passed in.
+    // This saves another tree lookup and is faster.
+    ReadableSlice src(source_addr, sz);
+    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
+  } else {
+    RETURN_IF_NOT_OK(cp_->Read(key, &dest, &bytesRead));
+    if (bytesRead != sz) {
+      std::string errMsg = "Unexpected length. Read " + std::to_string(bytesRead) + ". Expected " + std::to_string(sz) +
+                           "." + " Internal key: " + std::to_string(key);
+      MS_LOG(ERROR) << errMsg;
+      RETURN_STATUS_UNEXPECTED(errMsg);
     }
   }
   return Status::OK();
@@ -312,7 +335,7 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) {
 Status CacheService::FetchSchema(std::string *out) const {
   SharedLock rw(&rw_lock_);
-  if (st_ == State::kBuildPhase) {
+  if (st_ == CacheServiceState::kBuildPhase) {
     // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); } @@ -333,7 +356,7 @@ Status CacheService::BuildPhaseDone() { if (HasBuildPhase()) { // Exclusive lock to switch phase UniqueLock rw(&rw_lock_); - st_ = State::kFetchPhase; + st_ = CacheServiceState::kFetchPhase; cp_->SetLocking(false); return Status::OK(); } else { @@ -348,12 +371,12 @@ Status CacheService::ToggleWriteMode(bool on_off) { } else { // If we stop accepting write request, we turn off locking for the // underlying B+ tree. All future write request we will return kOutOfMemory. - if (st_ == State::kNone && !on_off) { - st_ = State::kNoLocking; + if (st_ == CacheServiceState::kNone && !on_off) { + st_ = CacheServiceState::kNoLocking; cp_->SetLocking(on_off); MS_LOG(WARNING) << "Locking mode is switched off."; - } else if (st_ == State::kNoLocking && on_off) { - st_ = State::kNone; + } else if (st_ == CacheServiceState::kNoLocking && on_off) { + st_ = CacheServiceState::kNone; cp_->SetLocking(on_off); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h index ab8a50775b..474cf526c2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h @@ -29,36 +29,28 @@ #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/engine/cache/cache_pool.h" #include "minddata/dataset/util/arena.h" #include "minddata/dataset/util/btree.h" -#include "minddata/dataset/util/cache_pool.h" #include "minddata/dataset/util/service.h" #include "minddata/dataset/util/services.h" #include "minddata/dataset/util/system_pool.h" namespace mindspore { namespace dataset { -/// Some typedef used for BatchFetch -using key_size_pair = std::pair; /// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is /// created to support spilling class CacheService : public Service { public: friend class CacheServer; - enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking }; - /// \brief Constructor /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited /// \param root Spill path. Empty string means no spilling /// \param generate_id If the cache service should generate row id for buffer that is cached. /// For non-mappable dataset, this should be set to true. CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); - ~CacheService(); - - /// \brief For fixed size memory, we will create an Arena. - /// \return false if unlimited memory. - bool UseArena(); + ~CacheService() override; Status DoServiceStart() override; Status DoServiceStop() override; @@ -77,18 +69,18 @@ class CacheService : public Service { Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated); /// \brief This function is used in preparation for batch fetching. - /// It calculates how much memory we should allocate and which row id are present. - /// \param[in/out] Pointer to vector of - /// \param[in/out] mem_sz how much memory is required to batch fetch + /// It calculates how much memory we should allocate and which row id are present, etc. + /// All needed results are stored in the flat buffer. 
/// \return Status object - Status PreBatchFetch(const std::vector &v, std::vector *, int64_t *mem_sz); + Status PreBatchFetch(connection_id_type connection_id, const std::vector &v, + const std::shared_ptr &); /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. /// \param[in] v A vector of row id. /// \param[out] out A contiguous memory buffer that holds the requested rows. /// \return Status object - Status BatchFetch(const std::vector &v, const std::vector &, WritableSlice *out) const; + Status BatchFetch(const std::shared_ptr &, WritableSlice *out) const; /// \brief Getter function /// \return Spilling path @@ -96,7 +88,7 @@ class CacheService : public Service { /// \brief A structure returned from the cache server for statistics request. class ServiceStat { public: - using state_type = std::underlying_type::type; + using state_type = std::underlying_type::type; ServiceStat() : state_(0) {} ~ServiceStat() = default; CachePool::CacheStat stat_{}; @@ -134,10 +126,6 @@ class CacheService : public Service { /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. /// \return Status object Status BuildPhaseDone(); - /// \brief Find out the current memory usage - int64_t GetMemoryUsage() { return cur_mem_usage_; } - /// \brief Find out the current disk usage - int64_t GetDiskUsage() { return cur_disk_usage_; } /// \brief For kToggleWriteMode request Status ToggleWriteMode(bool on_off); @@ -149,14 +137,10 @@ class CacheService : public Service { std::atomic next_id_; bool generate_id_; std::string cookie_; - State st_; + std::atomic num_clients_; + CacheServiceState st_; std::string schema_; - // If we use an Arena, cur_disk_usage is always 0 as we don't know how CachePool manages it. - // Otherwise we track how much is in memory and how much is on disk (if root_ is not empty). - // We use them to control when we should stop caching in memory in the case when there is no - // Arena. - std::atomic cur_mem_usage_; - std::atomic cur_disk_usage_; + std::shared_ptr numa_pool_; // We also cache the result from calling FindKeysMiss because it is expensive. Besides user make // this request after we hit memory full or disk full. So the result is unlikely to change. std::mutex get_key_miss_mux_; @@ -164,6 +148,8 @@ class CacheService : public Service { /// \brief Private function to generate a row id /// \return Row id assigned. 
row_id_type GetNextRowId() { return next_id_.fetch_add(1); } + + Status InternalFetchRow(const FetchRowMsg *p); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs index 5986f379f7..6cee18c905 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs @@ -65,6 +65,7 @@ table ServiceStatMsg { num_mem_cached:int64; num_disk_cached:int64; avg_cache_sz:int64; + num_numa_hit:int64; min_row_id:int64; max_row_id:int64; state:int8; @@ -89,8 +90,10 @@ table CreateCacheRequestMsg { /// Return result of CreateCacheRequest table CreateCacheReplyMsg { - connection_id:int64; + client_id:int32; + connection_id:uint64; cookie:string; + cpu_id:[int32]; } table ListSessionMsg { @@ -102,3 +105,22 @@ table ListSessionMsg { table ListSessionsMsg { sessions:[ListSessionMsg]; } + +table DataLocatorMsg { + key:int64; + node_id:int32; + addr:int64; + size:int64; +} + +table BatchDataLocatorMsg { + connection_id:uint64; + rows:[DataLocatorMsg]; +} + +table FetchRowMsg { + key:int64; + source_addr:int64; + dest_addr:int64; + size:int64; +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/CMakeLists.txt new file mode 100644 index 0000000000..c33e2ac9d2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/CMakeLists.txt @@ -0,0 +1,32 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + +if (ENABLE_CACHE) + ms_protobuf_generate(CACHE_PERF_PROTO_SRCS CACHE_PERF_PROTO_HDRS cache_perf.proto) + + add_executable(cache_perf cache_perf.cc cache_msg.cc cache_perf_run.cc ${CACHE_PERF_PROTO_SRCS}) + target_link_libraries(cache_perf + _c_dataengine + _c_mindrecord + mindspore::protobuf + mindspore_gvar + ${PYTHON_LIBRARIES} + pthread) + + if (USE_GLOG) + target_link_libraries(cache_perf mindspore::glog) + endif () + + add_executable(cache_pipeline cache_pipeline.cc cache_msg.cc cache_pipeline_run.cc ${CACHE_PERF_PROTO_SRCS}) + target_link_libraries(cache_pipeline + _c_dataengine + _c_mindrecord + mindspore::protobuf + mindspore_gvar + ${PYTHON_LIBRARIES} + pthread) + + if (USE_GLOG) + target_link_libraries(cache_pipeline mindspore::glog) + endif () +endif () diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.cc new file mode 100644 index 0000000000..9a63788b9d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/engine/cache/perf/cache_msg.h" +#include +#include +#include + +namespace mindspore { +namespace dataset { +Status CachePerfMsg::Send(int32_t qID) { + auto err = msgsnd(qID, reinterpret_cast(&small_msg_), sizeof(small_msg_.body.msg), IPC_NOWAIT); + if (err == -1) { + std::string errMsg = "Failed to call msgsnd. Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + return Status::OK(); +} + +Status CachePerfMsg::Receive(int32_t qID) { + // This is a blocking call. Either there is some message or we the queue is removed when + // the destructor is called. + auto err = msgrcv(qID, reinterpret_cast(&small_msg_), sizeof(small_msg_.body.msg), 0, MSG_NOERROR); + if (err == -1) { + if (errno == EIDRM) { + return Status(StatusCode::kInterrupted); + } else { + std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.h new file mode 100644 index 0000000000..b94f715a70 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_msg.h @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ + +#include +#include +#include +#include "proto/cache_perf.pb.h" +#include "minddata/dataset/engine/cache/cache_common.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// All our messages are very small. So we will use the stack version without the need +// to allocate memory. +struct CacheSmallMsg { + int64_t mtype; + union { + char mtext[1]; + struct { + int32_t type; // the first 4 bytes is the RequestType + int32_t proto_sz; + char proto_buffer[kSharedMessageSize]; + } msg; + } body; +}; +/// A message queue structure between the parent and the child process +class CachePerfMsg { + public: + enum MessageType : int16_t { + kInterrupt = 0, + kEpochResult = 1, + kEpochStart = 2, + kEpochEnd = 3, + kError = 4, + // Add new message before it. 
+ kUnknownMessage = 32767 + }; + CachePerfMsg() : small_msg_{1} { + small_msg_.body.msg.type = kUnknownMessage; + small_msg_.body.msg.proto_sz = 0; + small_msg_.body.msg.proto_buffer[0] = 0; + } + ~CachePerfMsg() = default; + + char *GetMutableBuffer() { return small_msg_.body.msg.proto_buffer; } + + Status Send(int32_t qID); + + void SetType(MessageType requestType) { small_msg_.body.msg.type = requestType; } + void SetProtoBufSz(size_t sz) { small_msg_.body.msg.proto_sz = sz; } + + MessageType GetType() const { return static_cast(small_msg_.body.msg.type); } + size_t GetProtoBufSz() const { return small_msg_.body.msg.proto_sz; } + + Status Receive(int32_t qID); + + private: + CacheSmallMsg small_msg_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_MSG_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.cc new file mode 100644 index 0000000000..990f1f518d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifdef USE_GLOG +#include +#endif +#include +#include "minddata/dataset/engine/cache/perf/cache_perf_run.h" +namespace ds = mindspore::dataset; + +int main(int argc, char **argv) { +#ifdef USE_GLOG + FLAGS_log_dir = "/tmp"; + google::InitGoogleLogging(argv[0]); +#endif + ds::CachePerfRun cachePerfRun; + if (cachePerfRun.ProcessArgs(argc, argv) == 0) { + std::cout << cachePerfRun << std::endl; + ds::Status rc = cachePerfRun.Run(); + if (rc.IsError()) { + std::cerr << rc.ToString() << std::endl; + } + return static_cast(rc.get_code()); + } + return 0; +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.proto b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.proto new file mode 100644 index 0000000000..d4a00273a8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf.proto @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; +package mindspore.dataset; +option cc_enable_arenas = true; + +message PipelineWorkerEpochSummary { + int32 pipeline = 1; + int32 worker = 2; + int64 min = 3; + int64 max = 4; + int64 avg = 5; + int64 med = 6; + int64 cnt = 7; + int64 elapse = 8; +} + +message EpochDone { + int32 pipeline = 1; +} + +message ErrorMsg { + int32 rc = 1; + string msg = 2; +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc new file mode 100644 index 0000000000..acf2b29028 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.cc @@ -0,0 +1,575 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "minddata/dataset/engine/cache/perf/cache_perf_run.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/sig_handler.h" + +namespace mindspore { +namespace dataset { +const char CachePerfRun::kCachePipelineBinary[] = "cache_pipeline"; +void CachePerfRun::PrintHelp() { + std::cout << "Options:\n" + " -h,--help: Show this usage message\n" + " -s,--num_rows: Set the sample size, i.e., the number of " + "rows\n" + " -r,--row_size: Set the average row size\n" + " -n,--pipeline: Set the number of parallel pieplines. Default = " + << kDftNumOfPipelines + << "\n" + " -e,--epoch: Set the number of epochs. Default = " + << kDftNumberOfEpochs + << "\n" + " --shuffle: Set shuffle=True. Default = " + << std::boolalpha << kDftShuffle + << "\n" + " -p,--prefetch_size: Set the prefetch size for cache. Default = " + << kDftPrefetchSize << "\n" + << " -a,--cache_size: Set cache size. Default = " << kDftCacheSize + << " (Mb)\n" + " --spill: Set spill to disk to True. Default = " + << std::boolalpha << kDftSpill << "\n" + << " -w,--workers: Set the number of parallel workers. Default = " << cfg_.num_parallel_workers() + << "\n" + " --connection: Set number of TCP/IP connections per pipeline. Default = " + << kDftNumConnections << "\n" + << " --port: TCP/IP port of the cache server. Default = " << kCfgDefaultCachePort << "\n" + << " --hostname: Hostname of the cache server. 
Default = " << kCfgDefaultCacheHost << "\n"; +} + +int32_t CachePerfRun::ProcessArgs(int argc, char **argv) { + if (argc == 1) { + PrintHelp(); + return -1; + } + + const int32_t port_opt = 1000; // there is no short option for port + const int32_t hostname_opt = 1001; // there is no short option for hostname + const int32_t connect_opt = 1002; // there is no short option for connect + + int shuffle = 0; + int spill = 0; + + const char *const short_opts = ":n:e:p:a:s:r:w:"; + const option long_opts[] = {{"pipeline", required_argument, nullptr, 'n'}, + {"epoch", required_argument, nullptr, 'e'}, + {"prefetch_size", required_argument, nullptr, 'p'}, + {"shuffle", no_argument, &shuffle, 1}, + {"cache_size", required_argument, nullptr, 'a'}, + {"num_rows", required_argument, nullptr, 's'}, + {"row_size", required_argument, nullptr, 'r'}, + {"workers", required_argument, nullptr, 'w'}, + {"port", required_argument, nullptr, port_opt}, + {"hostname", required_argument, nullptr, hostname_opt}, + {"spill", no_argument, &spill, 1}, + {"connection", required_argument, nullptr, connect_opt}, + {"help", no_argument, nullptr, 'h'}, + {nullptr, no_argument, nullptr, 0}}; + + std::map seen_opts; + int32_t rc = 0; + try { + while (rc == 0) { + int32_t option_indxex; + const auto opt = getopt_long(argc, argv, short_opts, long_opts, &option_indxex); + + if (-1 == opt) { + if (optind < argc) { + rc = -1; + std::cerr << "Unknown arguments: "; + while (optind < argc) { + std::cerr << argv[optind++] << " "; + } + std::cerr << std::endl; + } + break; + } + + if (opt > 0) { + seen_opts[opt]++; + if (seen_opts[opt] > 1) { + std::string long_name = long_opts[option_indxex].name; + std::cerr << "The " << long_name << " argument was given more than once." << std::endl; + rc = -1; + continue; + } + } + + switch (opt) { + case 0: { + if (long_opts[option_indxex].flag == &shuffle) { + shuffle_ = true; + } else if (long_opts[option_indxex].flag == &spill) { + cache_builder_.SetSpill(true); + } + break; + } + + case 'n': { + num_pipelines_ = std::stoi(optarg); + break; + } + + case 'e': { + num_epoches_ = std::stoi(optarg); + break; + } + + case 'p': { + int32_t prefetch_sz = std::stoi(optarg); + cache_builder_.SetPrefetchSize(prefetch_sz); + break; + } + + case 'a': { + int32_t cache_sz = std::stoi(optarg); + cache_builder_.SetCacheMemSz(cache_sz); + break; + } + + case 's': { + num_rows_ = std::stoi(optarg); + break; + } + + case 'r': { + row_size_ = std::stoi(optarg); + break; + } + + case 'w': { + cfg_.set_num_parallel_workers(std::stoi(optarg)); + break; + } + + case connect_opt: { + int32_t connection_sz = std::stoi(optarg); + cache_builder_.SetNumConnections(connection_sz); + break; + } + + case port_opt: { + int32_t port = std::stoi(optarg); + cache_builder_.SetPort(port); + break; + } + + case hostname_opt: { + std::string hostname = optarg; + cache_builder_.SetHostname(hostname); + break; + } + + case 'h': // -h or --help + PrintHelp(); + rc = -1; + break; + + case ':': + std::cerr << "Missing argument for option " << char(optopt) << std::endl; + rc = -1; + break; + + case '?': // Unrecognized option + default: + std::cerr << "Unknown option " << char(optopt) << std::endl; + PrintHelp(); + rc = -1; + break; + } + } + } catch (const std::exception &e) { + PrintHelp(); + rc = -1; + } + + if (rc < 0) { + return rc; + } + + // We have all the defaults except sample size and average row size which the user must specify. 
+ auto it = seen_opts.find('s'); + if (it == seen_opts.end()) { + std::cerr << "Missing sample size." << std::endl; + return -1; + } + + it = seen_opts.find('r'); + if (it == seen_opts.end()) { + std::cerr << "Missing average row size." << std::endl; + return -1; + } + + if (num_rows_ <= 0) { + std::cerr << "Sample size must be positive." << std::endl; + return -1; + } + + if (row_size_ <= 0) { + std::cerr << "Average row size must be positive." << std::endl; + return -1; + } + + if (num_pipelines_ <= 0) { + std::cerr << "Number of pipelines must be positive." << std::endl; + return -1; + } + + if (num_epoches_ <= 0) { + std::cerr << "Number of epoches must be positive." << std::endl; + return -1; + } + + if (num_rows_ < num_pipelines_) { + std::cerr << "Sample size is smaller than the number of pipelines." << std::endl; + return -1; + } + + pid_lists_.reserve(num_pipelines_); + + return 0; +} + +Status CachePerfRun::GetSession() { + CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1); + RETURN_IF_NOT_OK(comm.ServiceStart()); + auto rq = std::make_shared(); + RETURN_IF_NOT_OK(comm.HandleRequest(rq)); + RETURN_IF_NOT_OK(rq->Wait()); + session_ = rq->GetSessionId(); + std::cout << "Session: " << session_ << std::endl; + cache_builder_.SetSessionId(session_); + return Status::OK(); +} + +CachePerfRun::CachePerfRun() + : my_pipeline_(-1), + num_pipelines_(kDftNumOfPipelines), + num_epoches_(kDftNumberOfEpochs), + num_rows_(0), + row_size_(0), + shuffle_(kDftShuffle), + session_(0), + crc_(0), + epoch_sync_cnt_(0) { + cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize); +} + +CachePerfRun::~CachePerfRun() { + if (session_ != 0) { + Status rc; + CacheClientGreeter comm(cache_builder_.GetHostname(), cache_builder_.GetPort(), 1); + rc = comm.ServiceStart(); + if (rc.IsOk()) { + CacheClientInfo cinfo; + cinfo.set_session_id(session_); + auto rq = std::make_shared(cinfo); + rc = comm.HandleRequest(rq); + if (rc.IsOk()) { + rc = rq->Wait(); + if (rc.IsOk()) { + std::cout << "Drop session " << session_ << " successful" << std::endl; + } + } + } + } + // Send an interrupt message to each child. + for (auto msg_qid : msg_send_lists_) { + CachePerfMsg msg; + msg.SetType(CachePerfMsg::MessageType::kInterrupt); + (void)msg.Send(msg_qid); + } + // Wait for each child to return + for (auto pid : pid_lists_) { + int status; + if (waitpid(pid, &status, 0) == -1) { + std::string errMsg = "waitpid fails. errno = " + std::to_string(errno); + std::cerr << errMsg << std::endl; + } else { + MS_LOG(INFO) << "Child pid " << pid << " returns." << std::endl; + } + } + // Remove all the message queues + for (auto msg_qid : msg_send_lists_) { + // Remove the message que and never mind about the return code. + (void)msgctl(msg_qid, IPC_RMID, nullptr); + } + for (auto msg_qid : msg_recv_lists_) { + // Remove the message que and never mind about the return code. 
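+ // Removing the queue also unblocks any child still sitting in msgrcv; that call then
+ // fails with EIDRM, which CachePerfMsg::Receive maps to an interrupt status.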
+ (void)msgctl(msg_qid, IPC_RMID, nullptr); + } +} + +void CachePerfRun::PrintEpochSummary() const { + std::cout << std::setw(12) << "Pipeline #" << std::setw(10) << "worker id" << std::setw(11) << "min (μs)" + << std::setw(11) << "max (μs)" << std::setw(11) << "avg (μs)" << std::setw(14) << "median (μs)" + << std::setw(14) << "buffer count" << std::setw(18) << "Elapsed time (s)" << std::endl; + for (auto &it : epoch_results_) { + auto epoch_worker_summary = it.second; + std::cout << std::setw(12) << epoch_worker_summary.pipeline() + 1 << std::setw(10) << epoch_worker_summary.worker() + << std::setw(10) << epoch_worker_summary.min() << std::setw(10) << epoch_worker_summary.max() + << std::setw(10) << epoch_worker_summary.avg() << std::setw(13) << epoch_worker_summary.med() + << std::setw(14) << epoch_worker_summary.cnt() << std::setw(18) << epoch_worker_summary.elapse() + << std::endl; + } +} + +Status CachePerfRun::ListenToPipeline(int32_t workerId) { + TaskManager::FindMe()->Post(); + int32_t qID = msg_recv_lists_[workerId]; + do { + RETURN_IF_INTERRUPTED(); + CachePerfMsg msg; + RETURN_IF_NOT_OK(msg.Receive(qID)); + // Decode the messages. + auto type = msg.GetType(); + char *p = msg.GetMutableBuffer(); + switch (type) { + case CachePerfMsg::MessageType::kEpochResult: { + PipelineWorkerEpochSummary epoch_worker_summary; + CHECK_FAIL_RETURN_UNEXPECTED(epoch_worker_summary.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); + { + auto pipeline = epoch_worker_summary.pipeline(); + auto worker = epoch_worker_summary.worker(); + std::unique_lock lock(mux_); + // sort by pipeline/worker + auto r = + epoch_results_.emplace(std::pair(pipeline, worker), std::move(epoch_worker_summary)); + CHECK_FAIL_RETURN_UNEXPECTED(r.second, "Insert failed"); + } + break; + } + case CachePerfMsg::MessageType::kEpochEnd: { + EpochDone proto; + CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); + auto n = epoch_sync_cnt_.fetch_add(1); + if (n + 1 == num_pipelines_) { + pipeline_wp_.Set(); + } + break; + } + case CachePerfMsg::MessageType::kInterrupt: { + TaskManager::WakeUpWatchDog(); + return Status::OK(); + } + case CachePerfMsg::kError: { + ErrorMsg proto; + CHECK_FAIL_RETURN_UNEXPECTED(proto.ParseFromArray(p, msg.GetProtoBufSz()), "Parse fail"); + return Status(static_cast(proto.rc()), proto.msg()); + } + default: + std::string errMsg = "Unknown request type: " + std::to_string(type); + MS_LOG(ERROR) << errMsg; + RETURN_STATUS_UNEXPECTED(errMsg); + break; + } + } while (true); + return Status::OK(); +} + +Status CachePerfRun::Run() { + // Now we bring up TaskManager. + RETURN_IF_NOT_OK(Services::CreateInstance()); + // Handle Control-C + RegisterHandlers(); + + // Get a session from the server. + RETURN_IF_NOT_OK(GetSession()); + + // Generate a random crc. + auto mt = GetRandomDevice(); + std::uniform_int_distribution distribution(0, std::numeric_limits::max()); + crc_ = distribution(mt); + std::cout << "CRC: " << crc_ << std::endl; + + // Create all the resources required by the pipelines before we fork. + for (auto i = 0; i < num_pipelines_; ++i) { + // We will use shared message queues for communication between parent (this process) + // and each pipelines. + auto access_mode = S_IRUSR | S_IWUSR; + int32_t msg_send_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); + if (msg_send_qid == -1) { + std::string errMsg = "Unable to create a message queue. 
Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + msg_send_lists_.push_back(msg_send_qid); + int32_t msg_recv_qid = msgget(IPC_PRIVATE, IPC_CREAT | IPC_EXCL | access_mode); + if (msg_recv_qid == -1) { + std::string errMsg = "Unable to create a message queue. Errno = " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + msg_recv_lists_.push_back(msg_recv_qid); + } + + // Now we create the children knowing all two sets of message queues are constructed. + for (auto i = 0; i < num_pipelines_; ++i) { + auto pid = fork(); + if (pid == 0) { + // Child. We will call another binary but with different (hidden) parameters. + // The parent process is waiting on a wait post. Any error we hit here we must interrupt the + // parent process + auto interrupt_parent = [this, i]() { + CachePerfMsg msg; + msg.SetType(CachePerfMsg::MessageType::kInterrupt); + msg.Send(msg_recv_lists_[i]); + }; + const std::string self_proc = "/proc/self/exe"; + std::string canonical_path; + canonical_path.resize(400); // PATH_MAX is large. This value should be big enough for our use. + // Some lower level OS library calls are needed here to determine the binary path. + if (realpath(self_proc.data(), canonical_path.data()) == nullptr) { + std::cerr << "Failed to identify cache_perf binary path: " + std::to_string(errno) << ": " << strerror(errno) + << std::endl; + interrupt_parent(); + // Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process. + _exit(-1); + } + canonical_path.resize(strlen(canonical_path.data())); + int last_seperator = canonical_path.find_last_of('/'); + if (last_seperator == std::string::npos) { + std::cerr << "Canonical path can't locate / " << canonical_path << std::endl; + interrupt_parent(); + // Call _exit instead of exit because we will hang in TaskManager destructor for a forked child process. + _exit(-1); + } + // truncate the binary name so we are left with the absolute path of cache_admin binary + canonical_path.resize(last_seperator + 1); + std::string cache_pipeline_binary = canonical_path + std::string(kCachePipelineBinary); + + std::string pipeline_cfg = std::to_string(i) + "," + std::to_string(session_) + "," + std::to_string(crc_) + "," + + std::to_string(msg_send_lists_[i]) + "," + std::to_string(msg_recv_lists_[i]) + "," + + std::to_string(num_pipelines_) + "," + std::to_string(num_epoches_) + "," + + std::to_string(num_rows_) + "," + std::to_string(row_size_) + "," + + std::to_string(cfg_.num_parallel_workers()) + "," + + (shuffle_ ? std::string("true").data() : std::string("false").data()); + std::string client_cfg = cache_builder_.GetHostname() + "," + std::to_string(cache_builder_.GetPort()) + "," + + std::to_string(cache_builder_.GetPrefetchSize()) + "," + + std::to_string(cache_builder_.GetCacheMemSz()) + "," + + std::to_string(cache_builder_.GetNumConnections()) + "," + + (cache_builder_.isSpill() ? std::string("true").data() : std::string("false").data()); + char *argv[4]; + argv[0] = const_cast(kCachePipelineBinary); + argv[1] = pipeline_cfg.data(); + argv[2] = client_cfg.data(); + argv[3] = nullptr; + // Invoke the binary. + execv(cache_pipeline_binary.data(), argv); + std::cerr << "Unable to exec. Errno = " + std::to_string(errno) << ": " << strerror(errno) << std::endl; + interrupt_parent(); + // Call _exit instead of exit because we will hang TaskManager destructor for a forked child process. 
+ _exit(-1); + } else if (pid > 0) { + std::cout << "Pipeline number " << i + 1 << " has been created with process id: " << pid << std::endl; + pid_lists_.push_back(pid); + } else { + std::string errMsg = "Failed to fork process for cache pipeline: " + std::to_string(errno); + RETURN_STATUS_UNEXPECTED(errMsg); + } + } + + // Spawn a few threads to monitor the communications from the pipeline. + RETURN_IF_NOT_OK(vg_.ServiceStart()); + + auto f = std::bind(&CachePerfRun::ListenToPipeline, this, std::placeholders::_1); + for (auto i = 0; i < num_pipelines_; ++i) { + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(f, i))); + } + + // Wait until all pipelines finish the first epoch. + RETURN_IF_NOT_OK(pipeline_wp_.Wait()); + + std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer() + << std::endl; + PrintEpochSummary(); + + // Get some stat but we need to connect. The server will thinks it is the (n+1) pipeline + RETURN_IF_NOT_OK(cache_builder_.Build(&cc_)); + Status rc = cc_->CreateCache(crc_, false); + // Duplicate key is fine. + if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) { + return rc; + } + + CacheServiceStat stat{}; + RETURN_IF_NOT_OK(cc_->GetStat(&stat)); + + std::cout << "Get statistics for this session:\n"; + std::cout << std::setw(12) << "Mem cached" << std::setw(12) << "Disk cached" << std::setw(16) << "Avg cache size" + << std::setw(10) << "Numa hit" << std::endl; + std::string stat_mem_cached; + std::string stat_disk_cached; + std::string stat_avg_cached; + std::string stat_numa_hit; + stat_mem_cached = (stat.num_mem_cached == 0) ? "n/a" : std::to_string(stat.num_mem_cached); + stat_disk_cached = (stat.num_disk_cached == 0) ? "n/a" : std::to_string(stat.num_disk_cached); + stat_avg_cached = (stat.avg_cache_sz == 0) ? "n/a" : std::to_string(stat.avg_cache_sz); + stat_numa_hit = (stat.num_numa_hit == 0) ? "n/a" : std::to_string(stat.num_numa_hit); + + std::cout << std::setw(12) << stat_mem_cached << std::setw(12) << stat_disk_cached << std::setw(16) << stat_avg_cached + << std::setw(10) << stat_numa_hit << std::endl; + + // Toggle write mode off since the rest are just read only. + // Simplest way is call this special internal function. + cc_->ServerRunningOutOfResources(); + + // The rest of the epochs are just fetching. + auto epoch_num = 2; + while (epoch_num <= num_epoches_) { + epoch_sync_cnt_ = 0; + pipeline_wp_.Clear(); + epoch_results_.clear(); + // Signal each pipeline to start + for (auto msg_qid : msg_send_lists_) { + CachePerfMsg msg; + msg.SetType(CachePerfMsg::MessageType::kEpochStart); + (void)msg.Send(msg_qid); + } + // Wait for the child to finish + RETURN_IF_NOT_OK(pipeline_wp_.Wait()); + std::cout << "Epoch " << epoch_num + << " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl; + PrintEpochSummary(); + ++epoch_num; + } + + // Destroy the cache. We no longer need it around. 
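+ // All pipelines were built against the same crc_, so a single DestroyCache call from the
+ // parent removes the cache they shared.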
+ RETURN_IF_NOT_OK(cc_->DestroyCache()); + + // Unreserve the session + CacheClientInfo cinfo; + cinfo.set_session_id(session_); + auto rq = std::make_shared(cinfo); + RETURN_IF_NOT_OK(cc_->PushRequest(rq)); + RETURN_IF_NOT_OK(rq->Wait()); + std::cout << "Drop session " << session_ << " successful" << std::endl; + session_ = 0; + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.h b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.h new file mode 100644 index 0000000000..8b21954bb8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_perf_run.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/perf/cache_msg.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { + +constexpr int32_t kDftNumOfPipelines = 8; +constexpr int32_t kDftNumberOfEpochs = 10; +constexpr int32_t kDftCacheSize = 0; +constexpr bool kDftShuffle = false; +constexpr bool kDftSpill = false; + +class CachePerfRun { + public: + static const char kCachePipelineBinary[]; + CachePerfRun(); + ~CachePerfRun(); + void PrintHelp(); + int32_t ProcessArgs(int argc, char **argv); + + void Print(std::ostream &out) const { + out << "Number of pipelines: " << num_pipelines_ << "\n" + << "Number of epochs: " << num_epoches_ << "\n" + << "Sample size: " << num_rows_ << "\n" + << "Average row size: " << row_size_ << "\n" + << "Shuffle: " << std::boolalpha << shuffle_; + } + + friend std::ostream &operator<<(std::ostream &out, const CachePerfRun &cp) { + cp.Print(out); + return out; + } + + Status Run(); + + private: + std::mutex mux_; + int32_t my_pipeline_; + int32_t num_pipelines_; + int32_t num_epoches_; + int64_t num_rows_; + int32_t row_size_; + bool shuffle_; + CacheClient::Builder cache_builder_; + session_id_type session_; + int32_t crc_; + std::vector pid_lists_; + std::vector msg_send_lists_; + std::vector msg_recv_lists_; + TaskGroup vg_; + std::atomic epoch_sync_cnt_; + WaitPost pipeline_wp_; + std::map, PipelineWorkerEpochSummary> epoch_results_; + ConfigManager cfg_; + std::shared_ptr cc_; + + Status GetSession(); + Status ListenToPipeline(int32_t workerId); + void PrintEpochSummary() const; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PERF_RUN_H_ diff 
--git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc new file mode 100644 index 0000000000..130bc102e6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifdef USE_GLOG +#include +#endif +#include +#include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" +namespace ds = mindspore::dataset; + +int main(int argc, char **argv) { +#ifdef USE_GLOG + FLAGS_log_dir = "/tmp"; + FLAGS_minloglevel = google::WARNING; + google::InitGoogleLogging(argv[0]); +#endif + ds::CachePipelineRun cachePipelineRun; + if (cachePipelineRun.ProcessArgs(argc, argv) == 0) { + ds::Status rc = cachePipelineRun.Run(); + // If we hit any error, send the rc back to the parent. + if (rc.IsError()) { + ds::ErrorMsg proto; + proto.set_rc(static_cast(rc.get_code())); + proto.set_msg(rc.ToString()); + ds::CachePerfMsg msg; + (void)cachePipelineRun.SendMessage(&msg, ds::CachePerfMsg::MessageType::kError, &proto); + } + return static_cast(rc.get_code()); + } + return 0; +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc new file mode 100644 index 0000000000..280bc08052 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.cc @@ -0,0 +1,471 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "minddata/dataset/engine/cache/perf/cache_pipeline_run.h" +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/services.h" + +namespace mindspore { +namespace dataset { +void CachePipelineRun::PrintHelp() { std::cout << "Please run the executable cache_perf instead." 
<< std::endl; } + +int32_t CachePipelineRun::ProcessArgs(int argc, char **argv) { + if (argc != 3) { + PrintHelp(); + return -1; + } + + try { + std::stringstream cfg_ss(argv[1]); + std::string s; + int32_t numArgs = 0; + while (std::getline(cfg_ss, s, ',')) { + if (numArgs == 0) { + my_pipeline_ = std::stoi(s); + } else if (numArgs == 1) { + session_ = std::stoul(s); + cache_builder_.SetSessionId(session_); + } else if (numArgs == 2) { + crc_ = std::stoi(s); + } else if (numArgs == 3) { + recv_id_ = std::stoi(s); + } else if (numArgs == 4) { + send_id_ = std::stoi(s); + } else if (numArgs == 5) { + num_pipelines_ = std::stoi(s); + } else if (numArgs == 6) { + num_epoches_ = std::stoi(s); + } else if (numArgs == 7) { + num_rows_ = std::stol(s); + } else if (numArgs == 8) { + row_size_ = std::stoi(s); + } else if (numArgs == 9) { + cfg_.set_num_parallel_workers(std::stol(s)); + } else if (numArgs == 10) { + shuffle_ = strcmp(s.data(), "true") == 0; + } + ++numArgs; + } + if (numArgs != 11) { + std::cerr << "Incomplete arguments. Expect 11. But get " << numArgs << std::endl; + return -1; + } + std::stringstream client_ss(argv[2]); + numArgs = 0; + while (std::getline(client_ss, s, ',')) { + if (numArgs == 0) { + cache_builder_.SetHostname(s); + } else if (numArgs == 1) { + cache_builder_.SetPort(std::stoi(s)); + } else if (numArgs == 2) { + cache_builder_.SetPrefetchSize(std::stoi(s)); + } else if (numArgs == 3) { + cache_builder_.SetCacheMemSz(std::stoi(s)); + } else if (numArgs == 4) { + cache_builder_.SetNumConnections(std::stoi(s)); + } else if (numArgs == 5) { + cache_builder_.SetSpill(strcmp(s.data(), "true") == 0); + } + ++numArgs; + } + if (numArgs != 6) { + std::cerr << "Incomplete arguments. Expect 6. But get " << numArgs << std::endl; + return -1; + } + } catch (const std::exception &e) { + std::cerr << "Parse error: " << e.what() << std::endl; + return -1; + } + return 0; +} + +CachePipelineRun::CachePipelineRun() + : my_pipeline_(-1), + num_pipelines_(kDftNumOfPipelines), + num_epoches_(kDftNumberOfEpochs), + num_rows_(0), + row_size_(0), + shuffle_(kDftShuffle), + session_(0), + crc_(0), + send_id_(-1), + recv_id_(-1), + start_row_(-1), + end_row_(-1) { + cache_builder_.SetSpill(kDftSpill).SetCacheMemSz(kDftCacheSize); +} + +CachePipelineRun::~CachePipelineRun() { + CachePerfMsg msg; + (void)SendMessage(&msg, CachePerfMsg::MessageType::kInterrupt, nullptr); +} + +Status CachePipelineRun::ListenToParent() { + TaskManager::FindMe()->Post(); + do { + RETURN_IF_INTERRUPTED(); + CachePerfMsg msg; + RETURN_IF_NOT_OK(msg.Receive(recv_id_)); + // Decode the messages. + auto type = msg.GetType(); + switch (type) { + case CachePerfMsg::MessageType::kInterrupt: { + TaskManager::WakeUpWatchDog(); + return Status::OK(); + } + case CachePerfMsg::MessageType::kEpochStart: { + pipeline_wp_.Set(); + break; + } + default: + std::string errMsg = "Unknown request type: " + std::to_string(type); + MS_LOG(ERROR) << errMsg; + RETURN_STATUS_UNEXPECTED(errMsg); + break; + } + } while (true); + + return Status::OK(); +} + +Status CachePipelineRun::Run() { + RETURN_IF_NOT_OK(cache_builder_.Build(&cc_)); + RETURN_IF_NOT_OK(vg_.ServiceStart()); + + auto num_workers = cfg_.num_parallel_workers(); + io_block_queues_.Init(num_workers, cfg_.op_connector_size()); + + RETURN_IF_NOT_OK(io_block_queues_.Register(&vg_)); + + Status rc = cc_->CreateCache(crc_, false); + // Duplicate key is fine. 
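+ // Every pipeline (and the parent) creates the cache from the same crc_, so all but the first
+ // call comes back with kDuplicateKey, which we simply ignore.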
+ if (rc.IsError() && rc.get_code() != StatusCode::kDuplicateKey) { + return rc; + } + + // Log a warning level message so we can see it in the log file when it starts. + MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " successfully creating cache service." << std::endl; + + // Spawn a thread to listen to the parent process + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Queue listener", std::bind(&CachePipelineRun::ListenToParent, this))); + + RETURN_IF_NOT_OK(RunFirstEpoch()); + + // The rest of the epochs are just fetching. + auto remaining_epochs = num_epoches_ - 1; + while (remaining_epochs > 0) { + // Wait for parent process signal to start + pipeline_wp_.Wait(); + pipeline_wp_.Clear(); + RETURN_IF_NOT_OK(RunReadEpoch()); + --remaining_epochs; + } + + // The listener thread is blocked on a shared message queue. It will be waken up by + // the parent process which will send an interrupt message to it, and this program + // will exit. + RETURN_IF_NOT_OK(vg_.join_all()); + return Status::OK(); +} + +Status CachePipelineRun::RunFirstEpoch() { + auto sz = num_rows_ / num_pipelines_; + start_row_ = my_pipeline_ * sz; + end_row_ = (my_pipeline_ + 1) * sz - 1; + if (my_pipeline_ + 1 == num_pipelines_) { + end_row_ = num_rows_ - 1; + } + std::cout << "Pipeline number " << my_pipeline_ + 1 << " row id range: [" << start_row_ << "," << end_row_ << "]" + << std::endl; + + // Spawn the worker threads. + auto f = std::bind(&CachePipelineRun::WriterWorkerEntry, this, std::placeholders::_1); + std::vector worker_threads; + auto num_workers = cfg_.num_parallel_workers(); + worker_threads.reserve(num_workers); + for (int32_t i = 0; i < num_workers; ++i) { + Task *pTask; + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask)); + worker_threads.push_back(pTask); + } + + std::vector keys; + auto rows_per_buffer = cfg_.rows_per_buffer(); + keys.reserve(rows_per_buffer); + int32_t worker_id = 0; + for (auto i = start_row_; i <= end_row_; ++i) { + keys.push_back(i); + if (keys.size() == rows_per_buffer) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); + keys.clear(); + } + } + if (!keys.empty()) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); + keys.clear(); + } + + // Shutdown threads and wait for them to come back. + for (int32_t i = 0; i < num_workers; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + for (auto *pTask : worker_threads) { + RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); + } + + // Send a message saying epoch one done for this pipeline. 
+ EpochDone proto; + proto.set_pipeline(my_pipeline_); + CachePerfMsg msg; + RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto)); + + return Status::OK(); +} + +Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) { + Status rc; + TaskManager::FindMe()->Post(); + int64_t min_val = std::numeric_limits::max(); + int64_t max_val = 0; + int64_t total_val = 0; + int64_t cnt = 0; + std::vector duration; + duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers()); + bool resource_err = false; + auto col_desc = std::make_unique("int64", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1); + auto num_elements = row_size_ / sizeof(int64_t); + TensorShape shape(std::vector(1, num_elements)); + std::unique_ptr blk; + auto epoch_start = std::chrono::steady_clock::now(); + do { + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); + std::vector keys; + RETURN_IF_NOT_OK(blk->GetKeys(&keys)); + if (keys.empty()) { + // empty key is a quit signal for workers + break; + } + // Once we hit resource error, we drain the io block. No point to send anything to the server. + if (!resource_err) { + auto buffer = std::make_unique(cnt++, DataBuffer::kDeBFlagNone); + auto tensor_table = std::make_unique(); + for (auto id : keys) { + TensorRow row; + std::shared_ptr element; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, col_desc->type(), &element)); + row.setId(id); + // CreateEmpty allocates the memory but in virutal address. Let's commit the memory + // so we can get an accurate timing. + auto it = element->begin(); + for (auto i = 0; i < num_elements; ++i, ++it) { + *it = i; + } + row.push_back(std::move(element)); + tensor_table->push_back(std::move(row)); + } + buffer->set_tensor_table(std::move(tensor_table)); + // Measure the time to call WriteBuffer + auto start_tick = std::chrono::steady_clock::now(); + rc = cc_->WriteBuffer(std::move(buffer)); + auto end_tick = std::chrono::steady_clock::now(); + if (rc.IsError()) { + if (rc.IsOutofMemory() || rc.IsNoSpace()) { + MS_LOG(WARNING) << "Pipeline number " << my_pipeline_ + 1 << " worker id " << worker_id << ": " + << rc.ToString(); + resource_err = true; + cc_->ServerRunningOutOfResources(); + continue; + } else { + return rc; + } + } else { + int64_t ms = std::chrono::duration_cast(end_tick - start_tick).count(); + min_val = std::min(min_val, ms); + max_val = std::max(max_val, ms); + duration.push_back(ms); + total_val += ms; + } + } + } while (true); + + auto epoch_end = std::chrono::steady_clock::now(); + int64_t elapse_time = std::chrono::duration_cast(epoch_end - epoch_start).count(); + + PipelineWorkerEpochSummary proto; + proto.set_pipeline(my_pipeline_); + proto.set_worker(worker_id); + proto.set_min(min_val); + proto.set_max(max_val); + proto.set_elapse(elapse_time); + auto sz = duration.size(); + proto.set_cnt(sz); + if (sz > 0) { + // median + auto n = sz / 2; + std::nth_element(duration.begin(), duration.begin() + n, duration.end()); + auto median = duration[n]; + proto.set_med(median); + // average + int64_t avg = total_val / sz; + proto.set_avg(avg); + } + CachePerfMsg msg; + RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto)); + return Status::OK(); +} + +Status CachePipelineRun::RunReadEpoch() { + std::vector keys; + auto rows_per_buffer = cc_->GetPrefetchSize(); // We will use prefetch size to read. 
+ auto num_workers = cfg_.num_parallel_workers(); + keys.reserve(rows_per_buffer); + // Spawn workers + auto f = std::bind(&CachePipelineRun::ReaderWorkerEntry, this, std::placeholders::_1); + std::vector worker_threads; + worker_threads.reserve(num_workers); + for (int32_t i = 0; i < num_workers; ++i) { + Task *pTask; + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Parallel worker", std::bind(f, i), &pTask)); + worker_threads.push_back(pTask); + } + + std::vector all_keys; + all_keys.reserve(end_row_ - start_row_ + 1); + for (auto i = start_row_; i <= end_row_; ++i) { + all_keys.push_back((i)); + } + // If we need to shuffle the keys + if (shuffle_) { + std::shuffle(all_keys.begin(), all_keys.end(), GetRandomDevice()); + } + + int32_t worker_id = 0; + for (auto id : all_keys) { + keys.push_back(id); + if (keys.size() == rows_per_buffer) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); + keys.clear(); + } + } + if (!keys.empty()) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk))); + keys.clear(); + } + + // Shutdown threads and wait for them to come back. + for (int32_t i = 0; i < num_workers; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + for (auto *pTask : worker_threads) { + RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking)); + } + + // Send a message saying epoch one done for this pipeline. + EpochDone proto; + proto.set_pipeline(my_pipeline_); + CachePerfMsg msg; + RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochEnd, &proto)); + return Status::OK(); +} + +Status CachePipelineRun::ReaderWorkerEntry(int32_t worker_id) { + Status rc; + TaskManager::FindMe()->Post(); + int64_t min_val = std::numeric_limits::max(); + int64_t max_val = 0; + int64_t total_val = 0; + int64_t cnt = 0; + std::vector duration; + duration.reserve(num_rows_ / num_pipelines_ / cfg_.num_parallel_workers()); + std::unique_ptr blk; + auto epoch_start = std::chrono::steady_clock::now(); + do { + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); + std::vector keys; + RETURN_IF_NOT_OK(blk->GetKeys(&keys)); + if (keys.empty()) { + // empty key is a quit signal for workers + break; + } + std::vector prefetch_keys; + prefetch_keys.reserve(keys.size()); + + // Filter out all those keys that unlikely we will find at the server + for (auto row_id : keys) { + if (!cc_->KeyIsCacheMiss(row_id)) { + prefetch_keys.push_back(row_id); + } + } + // Early exit if nothing to fetch + if (prefetch_keys.empty()) { + continue; + } + // Get the rows from the server + TensorTable ttbl; + // Measure how long it takes for the row to come back. 
+ auto start_tick = std::chrono::steady_clock::now(); + RETURN_IF_NOT_OK(cc_->GetRows(prefetch_keys, &ttbl)); + auto end_tick = std::chrono::steady_clock::now(); + int64_t ms = std::chrono::duration_cast(end_tick - start_tick).count(); + min_val = std::min(min_val, ms); + max_val = std::max(max_val, ms); + duration.push_back(ms); + total_val += ms; + ++cnt; + } while (true); + + auto epoch_end = std::chrono::steady_clock::now(); + int64_t elapse_time = std::chrono::duration_cast(epoch_end - epoch_start).count(); + + PipelineWorkerEpochSummary proto; + proto.set_pipeline(my_pipeline_); + proto.set_worker(worker_id); + proto.set_min(min_val); + proto.set_max(max_val); + proto.set_elapse(elapse_time); + auto sz = duration.size(); + proto.set_cnt(sz); + if (sz > 0) { + // median + auto n = sz / 2; + std::nth_element(duration.begin(), duration.begin() + n, duration.end()); + auto median = duration[n]; + proto.set_med(median); + // average + int64_t avg = total_val / sz; + proto.set_avg(avg); + } + CachePerfMsg msg; + RETURN_IF_NOT_OK(SendMessage(&msg, CachePerfMsg::MessageType::kEpochResult, &proto)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.h b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.h new file mode 100644 index 0000000000..d13aec76de --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/perf/cache_pipeline_run.h @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/perf/cache_msg.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { + +constexpr int32_t kDftNumOfPipelines = 8; +constexpr int32_t kDftNumberOfEpochs = 10; +constexpr int32_t kDftCacheSize = 0; +constexpr bool kDftShuffle = false; +constexpr bool kDftSpill = false; + +class CachePipelineRun { + public: + CachePipelineRun(); + ~CachePipelineRun(); + static void PrintHelp(); + int32_t ProcessArgs(int argc, char **argv); + + void Print(std::ostream &out) const { + out << "Number of pipelines: " << num_pipelines_ << "\n" + << "Number of epochs: " << num_epoches_ << "\n" + << "Sample size: " << num_rows_ << "\n" + << "Average row size: " << row_size_ << "\n" + << "Shuffle: " << std::boolalpha << shuffle_; + } + + friend std::ostream &operator<<(std::ostream &out, const CachePipelineRun &cp) { + cp.Print(out); + return out; + } + + Status Run(); + + template + Status SendMessage(CachePerfMsg *msg, CachePerfMsg::MessageType type, T *proto) { + RETURN_UNEXPECTED_IF_NULL(msg); + msg->SetType(type); + if (proto != nullptr) { + auto size_needed = proto->ByteSizeLong(); + CHECK_FAIL_RETURN_UNEXPECTED( + size_needed <= kSharedMessageSize, + "Shared message set too small. Suggest to increase to " + std::to_string(size_needed)); + CHECK_FAIL_RETURN_UNEXPECTED(proto->SerializeToArray(msg->GetMutableBuffer(), kSharedMessageSize), + "Serialization fails"); + msg->SetProtoBufSz(size_needed); + } + RETURN_IF_NOT_OK(msg->Send(send_id_)); + return Status::OK(); + } + + private: + int32_t my_pipeline_; + int32_t num_pipelines_; + int32_t num_epoches_; + int64_t num_rows_; + int32_t row_size_; + bool shuffle_; + CacheClient::Builder cache_builder_; + session_id_type session_; + int32_t crc_; + TaskGroup vg_; + WaitPost pipeline_wp_; + int32_t send_id_; + int32_t recv_id_; + row_id_type start_row_; + row_id_type end_row_; + ConfigManager cfg_; + QueueList> io_block_queues_; // queues of IOBlocks + std::shared_ptr cc_; + + Status ListenToParent(); + Status RunFirstEpoch(); + Status RunReadEpoch(); + Status WriterWorkerEntry(int32_t worker_id); + Status ReaderWorkerEntry(int32_t worker_id); +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_PIPELINE_RUN_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.cc similarity index 98% rename from mindspore/ccsrc/minddata/dataset/util/storage_container.cc rename to mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.cc index cca10b40c4..d40a007bb2 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "minddata/dataset/util/storage_container.h" +#include "minddata/dataset/engine/cache/storage_container.h" #include #include diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.h b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.h similarity index 100% rename from mindspore/ccsrc/minddata/dataset/util/storage_container.h rename to mindspore/ccsrc/minddata/dataset/engine/cache/storage_container.h diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc similarity index 98% rename from mindspore/ccsrc/minddata/dataset/util/storage_manager.cc rename to mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc index bb70198fd1..2e16e843f5 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "minddata/dataset/util/storage_manager.h" +#include "minddata/dataset/engine/cache/storage_manager.h" #include #include diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.h b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h similarity index 97% rename from mindspore/ccsrc/minddata/dataset/util/storage_manager.h rename to mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h index 764ac83575..bd316f7fec 100644 --- a/mindspore/ccsrc/minddata/dataset/util/storage_manager.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/storage_manager.h @@ -21,6 +21,7 @@ #include #include #include +#include "minddata/dataset/engine/cache/storage_container.h" #include "minddata/dataset/util/allocator.h" #include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/lock.h" @@ -28,7 +29,6 @@ #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/service.h" #include "minddata/dataset/util/slice.h" -#include "minddata/dataset/util/storage_container.h" using ListOfContainers = std::vector>; namespace mindspore { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc index e970e6a69b..7ffcd3569e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -271,29 +271,18 @@ Status CacheBase::PrefetchRows(const std::vector &keys, std::vector } // Get the rows from the server TensorTable ttbl; - Status rc = cache_client_->GetRows(prefetch_keys, &ttbl); - if (rc.IsOk()) { - auto row_it = ttbl.begin(); - for (auto row_id : prefetch_keys) { - auto &row = *row_it; - if (row.empty()) { - cache_miss->push_back(row_id); - } - // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row - RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); - ++row_it; - } - } else { - // In case any thread is waiting for the rows to come back and blocked on a semaphore, - // we will put an empty row in the local cache. 
- for (auto row_id : prefetch_keys) { - TensorRow row; - row.setId(row_id); - RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); + RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl)); + auto row_it = ttbl.begin(); + for (auto row_id : prefetch_keys) { + auto &row = *row_it; + if (row.empty()) { cache_miss->push_back(row_id); } + // Put the prefetch row into the pool and wake up any WorkerEntry to wait for the row + RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); + ++row_it; } - return rc; + return Status::OK(); } Status CacheBase::Prefetcher(int32_t worker_id) { @@ -322,6 +311,16 @@ Status CacheBase::Prefetcher(int32_t worker_id) { return rc; } } while (rc.IsNetWorkError()); + // In case any thread is waiting for the rows to come back and blocked on a semaphore, + // we will put an empty row in the local cache. + if (rc.IsError() && AllowCacheMiss()) { + for (auto row_id : prefetch_keys) { + TensorRow row; + row.setId(row_id); + RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row))); + cache_miss.push_back(row_id); + } + } } else { if (AllowCacheMiss()) { // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h index c9e1bfbe9f..af47748b1c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -24,7 +24,6 @@ #include #include "minddata/dataset/engine/connector.h" #include "minddata/dataset/engine/cache/cache_client.h" -#include "minddata/dataset/engine/cache/cache_service.h" #include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index 0de3767f45..ca18109381 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -309,8 +309,7 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha if (st_.compare_exchange_strong(expected, State::kDirty)) { // We will do a deep copy but write directly into CacheRequest protobuf or shared memory Status rc; - cleaner_copy_ = - std::make_shared(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient()); + cleaner_copy_ = std::make_shared(cc.get()); rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row); if (rc.IsOk()) { // Send the request async. The cleaner will check the return code. 
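The prefetch change above moves the cache-miss fallback out of PrefetchRows() and into Prefetcher(): when GetRows() ultimately fails but AllowCacheMiss() is true, the prefetcher still publishes an empty row for every requested id, so any worker blocked waiting on one of those ids wakes up and takes the cache-miss path instead of hanging on the semaphore, and the ids are reported back as misses. Below is a minimal sketch of that wake-up pattern under a deliberately simplified pool; SimplePrefetchPool, Row and OnFetchFailed are hypothetical stand-ins for prefetch_, TensorRow and the Prefetcher logic, not the actual classes.

// Simplified model: consumers block until a row (possibly empty) is published for their id.
#include <condition_variable>
#include <cstdint>
#include <map>
#include <mutex>
#include <vector>

using RowId = int64_t;
using Row = std::vector<int>;  // stand-in for TensorRow; empty == cache miss

class SimplePrefetchPool {
 public:
  void Add(RowId id, Row row) {
    std::lock_guard<std::mutex> lk(mu_);
    rows_[id] = std::move(row);
    cv_.notify_all();  // wake every worker waiting on this pool
  }
  Row Pop(RowId id) {  // called by worker threads
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return rows_.count(id) > 0; });
    Row out = std::move(rows_[id]);
    rows_.erase(id);
    return out;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::map<RowId, Row> rows_;
};

// Producer side: the fetch failed but misses are allowed, so publish empty rows
// (unblocking any waiting worker) and remember the ids as cache misses.
void OnFetchFailed(SimplePrefetchPool *pool, const std::vector<RowId> &keys, std::vector<RowId> *cache_miss) {
  for (RowId id : keys) {
    cache_miss->push_back(id);
    pool->Add(id, Row{});
  }
}

In the patched operator the same idea sits behind the retry loop: network errors are retried first, and only a persistent failure with AllowCacheMiss() takes this fallback.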
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index 04abeb989e..f579bf165c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -153,7 +153,7 @@ Status CacheOp::WaitForCachingAllRows() { bool BuildPhaseDone = true; do { RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); - BuildPhaseDone = stat.cache_service_state == static_cast(CacheService::State::kFetchPhase); + BuildPhaseDone = stat.cache_service_state == static_cast(CacheServiceState::kFetchPhase); if (!BuildPhaseDone) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc index b3bb275dcf..a1c6e4e012 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace dataset { // Constructor -CacheErrorPass::CacheErrorPass() : is_cached_(false) {} +CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {} // Identifies the subtree below this node as being cached Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modified) { @@ -75,5 +75,81 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modifi return Status::OK(); } #endif + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that this is a tree with mappable leaf dataset + is_mappable_ = true; + return Status::OK(); +} + +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + // 
Turn off the flag that we're under a merge op + is_cached_ = false; + return Status::OK(); +} + +// Currently, returns an error if RepeatOp exists under a cache +// Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem. +Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_cached_ && is_mappable_) { + RETURN_STATUS_UNEXPECTED("Repeat is not supported as a descendant operator under a mappable cache."); + } + + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h index 5a20d2f35e..1e5816eb2a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h @@ -67,8 +67,81 @@ class CacheErrorPass : public NodePass { Status PreRunOnNode(std::shared_ptr node, bool *modified) override; #endif + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status 
RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the leaf dataset as being mappable + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the subtree above this node as not being cached + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies and block repeat under cache scenarios + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + private: bool is_cached_; + bool is_mappable_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt index f6af32fd28..2edae57720 100644 --- a/mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt @@ -3,7 +3,6 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE add_library(utils OBJECT arena.cc buddy.cc - cache_pool.cc circular_pool.cc data_helper.cc memory_pool.cc @@ -16,8 +15,6 @@ add_library(utils OBJECT lock.cc semaphore.cc status.cc - storage_container.cc - storage_manager.cc slice.cc path.cc wait_post.cc diff --git a/mindspore/ccsrc/minddata/dataset/util/allocator.h b/mindspore/ccsrc/minddata/dataset/util/allocator.h index 220d0a6d1b..2d006b418f 100644 --- a/mindspore/ccsrc/minddata/dataset/util/allocator.h +++ b/mindspore/ccsrc/minddata/dataset/util/allocator.h @@ -94,6 +94,11 @@ Status MakeUnique(std::unique_ptr> *out, C alloc, CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive"); try { T *data = alloc.allocate(n); + // Some of our implementation of allocator (e.g. NumaAllocator) don't throw std::bad_alloc. 
+ // So we have to catch for null ptr + if (data == nullptr) { + return Status(StatusCode::kOutOfMemory); + } if (!std::is_arithmetic::value) { for (auto i = 0; i < n; i++) { std::allocator_traits::construct(alloc, &(data[i]), std::forward(args)...); diff --git a/mindspore/ccsrc/minddata/dataset/util/path.h b/mindspore/ccsrc/minddata/dataset/util/path.h index 17dc015c70..5e7f9cd7b1 100644 --- a/mindspore/ccsrc/minddata/dataset/util/path.h +++ b/mindspore/ccsrc/minddata/dataset/util/path.h @@ -78,6 +78,18 @@ class Path { Path operator/(const char *); + bool operator==(const Path &rhs) const { return (path_ == rhs.path_); } + + bool operator!=(const Path &rhs) const { return (path_ != rhs.path_); } + + bool operator<(const Path &rhs) const { return (path_ < rhs.path_); } + + bool operator>(const Path &rhs) const { return (path_ > rhs.path_); } + + bool operator<=(const Path &rhs) const { return (path_ <= rhs.path_); } + + bool operator>=(const Path &rhs) const { return (path_ >= rhs.path_); } + bool Exists(); bool IsDirectory(); diff --git a/mindspore/ccsrc/minddata/dataset/util/task.cc b/mindspore/ccsrc/minddata/dataset/util/task.cc index 77336e56ae..d6e628670f 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task.cc +++ b/mindspore/ccsrc/minddata/dataset/util/task.cc @@ -37,6 +37,11 @@ void Task::operator()() { ss << Services::GetUniqueID(); #endif MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; + +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) + native_handle_ = pthread_self(); +#endif + try { // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can @@ -96,7 +101,8 @@ Task::Task(const std::string &myName, const std::function &f) task_group_(nullptr), is_master_(false), running_(false), - caught_severe_exception_(false) { + caught_severe_exception_(false), + native_handle_(0) { IntrpResource::ResetIntrpState(); wp_.ResetIntrpState(); wp_.Clear(); @@ -164,5 +170,10 @@ Status Task::OverrideInterruptRc(const Status &rc) { } return rc; } + +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) +pthread_t Task::GetNativeHandle() const { return native_handle_; } +#endif + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/task.h b/mindspore/ccsrc/minddata/dataset/util/task.h index b41861ecc0..ed327101f4 100644 --- a/mindspore/ccsrc/minddata/dataset/util/task.h +++ b/mindspore/ccsrc/minddata/dataset/util/task.h @@ -16,6 +16,9 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_H_ +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) +#include +#endif #include #include #include @@ -84,7 +87,7 @@ class Task : public IntrpResource { std::thread::id get_id() { return id_; } - std::string MyName() { return my_name_; } + std::string MyName() const { return my_name_; } // An operator used by std::find bool operator==(const Task &other) const { return (this == &other); } @@ -97,6 +100,10 @@ class Task : public IntrpResource { static Status OverrideInterruptRc(const Status &rc); +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) + pthread_t GetNativeHandle() const; +#endif + private: mutable std::mutex mux_; std::string my_name_; @@ -113,6 +120,12 @@ class Task : public IntrpResource { 
volatile bool running_; volatile bool caught_severe_exception_; +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) + pthread_t native_handle_; +#else + uint64_t native_handle_; +#endif + void ShutdownGroup(); TaskGroup *MyTaskGroup(); void set_task_group(TaskGroup *vg); diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index fcfc42a1be..1cd4328c5f 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -24,7 +24,6 @@ #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "minddata/dataset/util/storage_container.h" // lint !e322 #include "minddata/dataset/engine/datasetops/source/random_data_op.h" #include "minddata/dataset/engine/data_schema.h" diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh index a65ff8855c..6b7949471b 100755 --- a/tests/ut/python/cachetests/cachetest_py.sh +++ b/tests/ut/python/cachetests/cachetest_py.sh @@ -31,7 +31,7 @@ HandleRcExit $? 1 1 export RUN_CACHE_TEST=TRUE # Each of these tests will create session, use it, then destroy it after the test -for i in $(seq 1 6) +for i in $(seq 1 5) do test_name="test_cache_map_basic${i}" GetSession @@ -121,6 +121,12 @@ HandleRcExit $? 0 0 PytestCmd "test_cache_map.py" "test_cache_map_voc" 1 HandleRcExit $? 0 0 +PytestCmd "test_cache_map.py" "test_cache_map_python_sampler" 1 +HandleRcExit $? 0 0 + +PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat" +HandleRcExit $? 0 0 + # Run two parallel pipelines (sharing cache) for i in $(seq 1 2) do @@ -309,6 +315,9 @@ HandleRcExit $? 0 0 PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1 HandleRcExit $? 0 0 +PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat" +HandleRcExit $? 
0 0 + for i in $(seq 1 3) do test_name="test_cache_nomap_multiple_cache${i}" diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 452fb1271f..30c72808fa 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -107,49 +107,10 @@ def test_cache_map_basic2(): @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_basic3(): - """ - Test a repeat under mappable cache - - Cache - | - Map(decode) - | - Repeat - | - ImageFolder - """ - - logger.info("Test cache basic 3") - if "SESSION_ID" in os.environ: - session_id = int(os.environ['SESSION_ID']) - else: - raise RuntimeError("Testcase requires SESSION_ID environment variable") - - some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) - - # This DATA_DIR only has 2 images in it - ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) - decode_op = c_vision.Decode() - ds1 = ds1.repeat(4) - ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) - logger.info("ds1.dataset_size is ", ds1.get_dataset_size()) - - num_iter = 0 - for _ in ds1.create_dict_iterator(num_epochs=1): - logger.info("get data from dataset") - num_iter += 1 - - logger.info("Number of data in ds1: {} ".format(num_iter)) - assert num_iter == 8 - logger.info('test_cache_basic3 Ended.\n') - - -@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") -def test_cache_map_basic4(): """ Test different rows result in core dump """ - logger.info("Test cache basic 4") + logger.info("Test cache basic 3") if "SESSION_ID" in os.environ: session_id = int(os.environ['SESSION_ID']) else: @@ -171,11 +132,11 @@ def test_cache_map_basic4(): logger.info("Number of data in ds1: {} ".format(num_iter)) assert num_iter == 8 - logger.info('test_cache_basic4 Ended.\n') + logger.info('test_cache_basic3 Ended.\n') @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") -def test_cache_map_basic5(): +def test_cache_map_basic4(): """ Test Map with non-deterministic TensorOps above cache @@ -188,7 +149,7 @@ def test_cache_map_basic5(): ImageFolder """ - logger.info("Test cache failure 5") + logger.info("Test cache basic 4") if "SESSION_ID" in os.environ: session_id = int(os.environ['SESSION_ID']) else: @@ -211,11 +172,11 @@ def test_cache_map_basic5(): logger.info("Number of data in ds1: {} ".format(num_iter)) assert num_iter == 8 - logger.info('test_cache_failure5 Ended.\n') + logger.info('test_cache_basic4 Ended.\n') @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") -def test_cache_map_basic6(): +def test_cache_map_basic5(): """ Test cache as root node @@ -223,7 +184,7 @@ def test_cache_map_basic6(): | ImageFolder """ - logger.info("Test cache basic 6") + logger.info("Test cache basic 5") if "SESSION_ID" in os.environ: session_id = int(os.environ['SESSION_ID']) else: @@ -239,7 +200,7 @@ def test_cache_map_basic6(): logger.info("Number of data in ds1: {} ".format(num_iter)) assert num_iter == 2 - logger.info('test_cache_basic6 Ended.\n') + logger.info('test_cache_basic5 Ended.\n') @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") @@ -502,6 +463,7 @@ def test_cache_map_failure7(): Generator """ + def generator_1d(): for i in range(64): yield (np.array(i),) @@ -528,6 +490,44 @@ def 
test_cache_map_failure7(): logger.info('test_cache_failure7 Ended.\n') +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_failure8(): + """ + Test a repeat under mappable cache (failure) + + Cache + | + Map(decode) + | + Repeat + | + ImageFolder + """ + + logger.info("Test cache failure 8") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.repeat(4) + ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) + + with pytest.raises(RuntimeError) as e: + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + num_iter += 1 + assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value) + + assert num_iter == 0 + logger.info('test_cache_failure8 Ended.\n') + + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_parameter_check(): """ @@ -1702,6 +1702,125 @@ def test_cache_map_voc2(): logger.info("test_cache_map_voc2 Ended.\n") +class ReverseSampler(ds.Sampler): + def __iter__(self): + for i in range(self.dataset_size - 1, -1, -1): + yield i + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_python_sampler1(): + """ + Test using a python sampler, and cache after leaf + + Repeat + | + Map(decode) + | + cache + | + ImageFolder + """ + + logger.info("Test cache map python sampler1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_python_sampler1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_python_sampler2(): + """ + Test using a python sampler, and cache after map + + Repeat + | + cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map python sampler2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler()) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info("test_cache_map_python_sampler2 
Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_nested_repeat(): + """ + Test cache on pipeline with nested repeat ops + + Repeat + | + Map(decode) + | + Repeat + | + Cache + | + ImageFolder + """ + + logger.info("Test cache map nested repeat") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.repeat(4) + ds1 = ds1.map(operations=decode_op, input_columns=["image"]) + ds1 = ds1.repeat(2) + + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + logger.info("get data from dataset") + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 16 + logger.info('test_cache_map_nested_repeat Ended.\n') + + if __name__ == '__main__': test_cache_map_basic1() test_cache_map_basic2() diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index bb1580d2da..2269ce47f9 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -1292,6 +1292,50 @@ def test_cache_nomap_epoch_ctrl3(): logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_epoch_ctrl4(): + """ + Test using two-loops method with repeat under cache + + cache + | + Map(decode) + | + repeat + | + TFRecord + """ + + logger.info("Test cache nomap epoch ctrl4") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + ds1 = ds1.repeat(2) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + + num_epoch = 5 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + row_count = 0 + for _ in iter1: + row_count += 1 + logger.info("Number of data in ds1: {} ".format(row_count)) + assert row_count == 6 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n") + + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_nomap_multiple_cache1(): """ @@ -1734,6 +1778,47 @@ def test_cache_nomap_textfile2(): logger.info("test_cache_nomap_textfile2 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_nested_repeat(): + """ + Test cache on pipeline with nested repeat ops + + Repeat + | + Cache + | + Map(decode) + | + Repeat + | + TFRecord + """ + + logger.info("Test cache nomap nested repeat") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + + # This dataset has 3 records in it only + 
ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.repeat(4) + ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) + ds1 = ds1.repeat(2) + + num_iter = 0 + for _ in ds1.create_dict_iterator(num_epochs=1): + logger.info("get data from dataset") + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 24 + logger.info('test_cache_nomap_nested_repeat Ended.\n') + + if __name__ == '__main__': test_cache_nomap_basic1() test_cache_nomap_basic2() diff --git a/tests/ut/python/test_server_stop_testcase.sh b/tests/ut/python/test_server_stop_testcase.sh new file mode 100755 index 0000000000..7187e00908 --- /dev/null +++ b/tests/ut/python/test_server_stop_testcase.sh @@ -0,0 +1,10 @@ +~/cache/cache_admin --start +session_id=$(~/cache/cache_admin -g | awk '{print $NF}') +export SESSION_ID=${session_id} +pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop & +pid=("$!") + +sleep 2 +~/cache/cache_admin --stop +sleep 1 +wait ${pid}
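Taken together, the new CacheErrorPass overloads and the tests above draw one line: with a mappable leaf such as ImageFolder, a repeat may sit above the cache point but not below it, while non-mappable sources such as TFRecord still allow repeat under the cache (the nested-repeat test above reads 3 rows x 4 x 2 = 24). The condensed sketch below restates that contrast outside of any test; it reuses the DatasetCache and SESSION_ID conventions from the tests, and DATA_DIR here is a placeholder rather than the suite's real image folder.

import os

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision

DATA_DIR = "./testImageNetData/train/"  # placeholder; the tests above define their own DATA_DIR
some_cache = ds.DatasetCache(session_id=int(os.environ['SESSION_ID']), size=0, spilling=True)

# Allowed: cache on the mappable leaf, repeat above it (2 images * 4 = 8 rows).
ds_ok = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
ds_ok = ds_ok.map(operations=c_vision.Decode(), input_columns=["image"])
ds_ok = ds_ok.repeat(4)

# Rejected: the repeat is a descendant of a cached map over a mappable leaf.
ds_bad = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
ds_bad = ds_bad.repeat(4)
ds_bad = ds_bad.map(operations=c_vision.Decode(), input_columns=["image"], cache=some_cache)
# Iterating ds_bad raises:
#   RuntimeError: Repeat is not supported as a descendant operator under a mappable cache.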