diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt index a0b4382dfd..225df785e1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt @@ -62,7 +62,7 @@ if (ENABLE_CACHE) endif () add_executable(cache_admin cache_admin.cc cache_admin_arg.cc) - target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES}) + target_link_libraries(cache_admin _c_dataengine _c_mindrecord mindspore::protobuf ${PYTHON_LIBRARIES} pthread) if (USE_GLOG) target_link_libraries(cache_admin mindspore::glog) diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc index a7e26b08de..fa92048c49 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc @@ -27,6 +27,8 @@ #include #include "minddata/dataset/engine/cache/cache_request.h" #include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/engine/cache/cache_ipc.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/core/constants.h" @@ -325,9 +327,33 @@ Status CacheAdminArgHandler::RunCommand() { Help(); break; } - case CommandId::kCmdStart: + case CommandId::kCmdStart: { + RETURN_IF_NOT_OK(StartServer(command_id_)); + break; + } case CommandId::kCmdStop: { - RETURN_IF_NOT_OK(StartStopServer(command_id_)); + CacheClientGreeter comm(hostname_, port_, 1); + RETURN_IF_NOT_OK(comm.ServiceStart()); + SharedMessage msg; + RETURN_IF_NOT_OK(msg.Create()); + auto rq = std::make_shared(msg.GetMsgQueueId()); + RETURN_IF_NOT_OK(comm.HandleRequest(rq)); + Status rc = rq->Wait(); + if (rc.IsError()) { + msg.RemoveResourcesOnExit(); + if (rc.IsNetWorkError()) { + std::string errMsg = "Server is not 
up or has been shutdown already."; + return Status(StatusCode::kNetWorkError, errMsg); + } + return rc; + } + // OK return code only means the server acknowledge our request but we still + // have to wait for its complete shutdown because the server will shutdown + // the comm layer as soon as the request is received, and we need to wait + // on the message queue instead. + // The server will remove the queue and we will then wake up. + Status dummy_rc; + (void)msg.ReceiveStatus(&dummy_rc); break; } case CommandId::kCmdGenerateSession: { @@ -396,7 +422,7 @@ Status CacheAdminArgHandler::RunCommand() { return Status::OK(); } -Status CacheAdminArgHandler::StartStopServer(CommandId command_id) { +Status CacheAdminArgHandler::StartServer(CommandId command_id) { // There currently does not exist any "install path" or method to identify which path the installed binaries will // exist in. As a temporary approach, we will assume that the server binary shall exist in the same path as the // cache_admin binary (this process). @@ -477,23 +503,15 @@ Status CacheAdminArgHandler::StartStopServer(CommandId command_id) { std::string memory_cap_ratio_string = std::to_string(memory_cap_ratio_); char *argv[9]; - if (command_id == CommandId::kCmdStart) { - argv[0] = cache_server_binary.data(); - argv[1] = spill_dir_.data(); - argv[2] = workers_string.data(); - argv[3] = port_string.data(); - argv[4] = shared_memory_string.data(); - argv[5] = minloglevel_string.data(); - argv[6] = daemonize_string.data(); - argv[7] = memory_cap_ratio_string.data(); - argv[8] = nullptr; - } else { - // We are doing a --stop. Change the name to '-' and we also need the port number. - // The rest we don't need. 
- argv[0] = std::string("-").data(); - argv[1] = port_string.data(); - argv[2] = nullptr; - } + argv[0] = cache_server_binary.data(); + argv[1] = spill_dir_.data(); + argv[2] = workers_string.data(); + argv[3] = port_string.data(); + argv[4] = shared_memory_string.data(); + argv[5] = minloglevel_string.data(); + argv[6] = daemonize_string.data(); + argv[7] = memory_cap_ratio_string.data(); + argv[8] = nullptr; // Now exec the binary execv(cache_server_binary.data(), argv); @@ -509,17 +527,27 @@ void CacheAdminArgHandler::Help() { std::cerr << "Syntax:\n"; std::cerr << " cache_admin [--start | --stop]\n"; std::cerr << " [ [-h | --hostname] ]\n"; + std::cerr << " Default is " << kCfgDefaultCacheHost << ".\n"; std::cerr << " [ [-p | --port] ]\n"; + std::cerr << " Possible values are in range [1025..65535].\n"; + std::cerr << " Default is " << kCfgDefaultCachePort << ".\n"; std::cerr << " [ [-g | --generate_session] ]\n"; std::cerr << " [ [-d | --destroy_session] ]\n"; std::cerr << " [ [-w | --workers] ]\n"; + std::cerr << " Possible values are in range [1...max(100, Number of CPU)].\n"; + std::cerr << " Default is " << kDefaultNumWorkers << ".\n"; std::cerr << " [ [-s | --spilldir] ]\n"; + std::cerr << " Default is " << kDefaultSpillDir << ".\n"; std::cerr << " [ [-l | --minloglevel] ]\n"; + std::cerr << " Possible values are 0, 1, 2 and 3.\n"; + std::cerr << " Default is 1 (info level).\n"; std::cerr << " [ --list_sessions ]\n"; // Do not expose these option to the user via help or documentation, but the options do exist to aid with // development and tuning. 
// std::cerr << " [ [-m | --shared_memory_size] ]\n"; + // std::cerr << " Default is " << kDefaultSharedMemorySizeInGB << " (Gb in unit).\n"; // std::cerr << " [ [-r | --memory_cap_ratio] ]\n"; + // std::cerr << " Default is " << kMemoryCapRatio << ".\n"; std::cerr << " [--help]" << std::endl; } } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h index 020cb1b415..f3ba121a8d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h @@ -78,7 +78,7 @@ class CacheAdminArgHandler { kArgNumArgs = 14 // Must be the last position to provide a count }; - Status StartStopServer(CommandId); + Status StartServer(CommandId command_id); Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream, CommandId command_id = CommandId::kCmdUnknown); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc index a40362031b..e7096db469 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc @@ -21,11 +21,7 @@ CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) // We create the shared memory and we will destroy it. All other client just detach only. shm_.RemoveResourcesOnExit(); } -CachedSharedMemoryArena::~CachedSharedMemoryArena() { - // Also remove the path we use to generate ftok. 
- Path p(PortToUnixSocketPath(port_)); - (void)p.Remove(); -} +CachedSharedMemoryArena::~CachedSharedMemoryArena() {} Status CachedSharedMemoryArena::CreateArena(std::unique_ptr *out, int32_t port, size_t val_in_GB) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h index 61e430aa1f..6f61960e69 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h @@ -72,6 +72,9 @@ class CachedSharedMemoryArena : public MemoryPool { return os; } + /// \brief Get the shared memory key of the shared memory + SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); } + private: mutable std::mutex mux_; int32_t val_in_GB_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc index 20804adaec..8aaf98fde1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include "minddata/dataset/engine/cache/cache_grpc_server.h" #include "minddata/dataset/engine/cache/cache_server.h" #include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/task_manager.h" #ifndef ENABLE_ANDROID #include "utils/log_adapter.h" #else @@ -25,7 +27,7 @@ namespace mindspore { namespace dataset { CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb) - : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) { + : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) { // Setup a path for unix socket. 
unix_socket_ = PortToUnixSocketPath(port); // We can't generate the ftok key yet until the unix_socket_ is created @@ -73,7 +75,8 @@ Status CacheServerGreeterImpl::Run() { MS_LOG(INFO) << "Server listening on " << server_address; #if CACHE_LOCAL_CLIENT RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_)); - MS_LOG(INFO) << "Creation of local socket and shared memory successful"; + shm_key_ = shm_pool_->GetKey(); + MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_; auto cs = CacheServer::GetInstance().GetHWControl(); // This shared memory is a hot memory and we will interleave among all the numa nodes. cs->InterleaveMemory(const_cast(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L); @@ -181,5 +184,32 @@ void CacheServerRequest::Print(std::ostream &out) const { out << " "; BaseRequest::Print(out); } + +Status CacheServerGreeterImpl::MonitorUnixSocket() { + TaskManager::FindMe()->Post(); +#if CACHE_LOCAL_CLIENT + Path p(unix_socket_); + do { + RETURN_IF_INTERRUPTED(); + // If the unix socket is recreated for whatever reason, this server instance will be stale and + // no other process can communicate with us. In this case we need to shutdown ourselves. + if (p.Exists()) { + SharedMemory::shm_key_t key; + RETURN_IF_NOT_OK(PortToFtok(port_, &key)); + if (key != shm_key_) { + std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key_) + + ". New key " + std::to_string(key) + ". 
Shutting down server"; + MS_LOG(ERROR) << errMsg; + RETURN_STATUS_UNEXPECTED(errMsg); + } + } else { + MS_LOG(WARNING) << "Unix socket is removed."; + TaskManager::WakeUpWatchDog(); + } + std::this_thread::sleep_for(std::chrono::seconds(5)); + } while (true); +#endif + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h index a85ec8da07..d8bf2ed6fd 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h @@ -87,6 +87,9 @@ class CacheServerGreeterImpl final { /// \return Return the shared memory pool CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); } + /// \brief Monitor the status of the unix socket in case it is gone. + Status MonitorUnixSocket(); + void Shutdown(); private: @@ -97,6 +100,7 @@ class CacheServerGreeterImpl final { std::unique_ptr cq_; std::unique_ptr server_; std::unique_ptr shm_pool_; + SharedMemory::shm_key_t shm_key_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h index ce174ff808..946104686a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_ipc.h @@ -160,6 +160,9 @@ class SharedMemory : public BaseIPC { /// \brief Set the public key void SetPublicKey(key_t public_key) { shm_key_ = public_key; } + /// \brief Retrieve the key + shm_key_t GetKey() const { return shm_key_; } + /// \brief This returns where we attach to the shared memory. /// \return Base address of the shared memory. 
const void *SharedMemoryBaseAddr() const { return shmat_addr_; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc index 868f5f445b..42283dd534 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc @@ -27,116 +27,6 @@ #include "minddata/dataset/engine/cache/cache_ipc.h" namespace ds = mindspore::dataset; -/// Send a synchronous command to the local server using tcp/ip. -/// We aren't using any client code because this binary is not necessarily linked with the client library. -/// So just using grpc call directly. -/// \param port tcp/ip port to use -/// \param type Type of command. -/// \param out grpc result -/// \return Status object -ds::Status SendSyncCommand(int32_t port, ds::BaseRequest::RequestType type, ds::CacheRequest *rq, ds::CacheReply *reply, - grpc::Status *out) { - if (rq == nullptr) { - return ds::Status(ds::StatusCode::kUnexpectedError, "pointer rq is null"); - } - if (reply == nullptr) { - return ds::Status(ds::StatusCode::kUnexpectedError, "pointer reply is null"); - } - if (out == nullptr) { - return ds::Status(ds::StatusCode::kUnexpectedError, "pointer out is null"); - } - const std::string hostname = "127.0.0.1"; - auto unix_socket = ds::PortToUnixSocketPath(port); -#if CACHE_LOCAL_CLIENT - const std::string target = "unix://" + unix_socket; -#else - const std::string target = hostname + ":" + std::to_string(port); -#endif - try { - rq->set_type(static_cast(type)); - rq->set_client_id(-1); - rq->set_flag(0); - grpc::ChannelArguments args; - grpc::ClientContext ctx; - grpc::CompletionQueue cq; - // Standard async rpc call - std::shared_ptr channel = - grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args); - std::unique_ptr stub = ds::CacheServerGreeter::NewStub(channel); - std::unique_ptr> rpc = - stub->PrepareAsyncCacheServerRequest(&ctx, *rq, &cq); - 
rpc->StartCall(); - // We need to pass a tag. But since this is the only request in the completion queue and so we - // just pass a dummy - int64_t dummy; - void *tag; - bool success; - rpc->Finish(reply, out, &dummy); - // Now we wait on the completion queue synchronously. - auto r = cq.Next(&tag, &success); - if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) { - if (!success || tag != &dummy) { - std::string errMsg = "Unexpected programming error "; - return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } - if (out->ok()) { - return ds::Status(static_cast(reply->rc()), reply->msg()); - } else { - auto error_code = out->error_code(); - std::string errMsg = out->error_message() + ". GRPC Code " + std::to_string(error_code); - return ds::Status(ds::StatusCode::kNetWorkError, errMsg); - } - } else { - std::string errMsg = "Unexpected queue rc = " + std::to_string(r); - return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } - } catch (const std::exception &e) { - return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); - } -} - -/// Stop the server -/// \param argv -/// \return Status object -ds::Status StopServer(int argc, char **argv) { - ds::Status rc; - ds::CacheServer::Builder builder; - std::string errMsg; - if (argc != 2) { - return ds::Status(ds::StatusCode::kSyntaxError); - } - int32_t port = strtol(argv[1], nullptr, 10); - // We will go through the builder to do some snaity check. We only need the port number - // to shut down the server. Null the root directory as we don't trigger the sanity code to write out anything - // to the spill directory. - builder.SetPort(port).SetRootDirectory(""); - // Part of the sanity check is check the shared memory. If the server is up and running, we expect - // the return code is kDuplicate. 
- rc = builder.SanityCheck(); - if (rc.IsOk()) { - errMsg = "Server is not up or has been shutdown already."; - return ds::Status(ds::StatusCode::kUnexpectedError, errMsg); - } else if (rc.get_code() != ds::StatusCode::kDuplicateKey) { - // Not OK, and no duplicate, just return the rc whatever it is. - return rc; - } else { - // Now we get some work to do. We will send a tcp/ip request to the given port. - // This binary is not linked with client side of code, so we will just call grpc directly. - ds::CacheRequest rq; - ds::CacheReply reply; - grpc::Status grpc_rc; - rc = SendSyncCommand(port, ds::BaseRequest::RequestType::kStopService, &rq, &reply, &grpc_rc); - // The request is like a self destruct message, the server will not send anything back and - // shutdown all incoming request. So we should expect some unexpected network error if - // all goes well and we expect to GRPC code 14. - auto err_code = grpc_rc.error_code(); - if (rc.get_code() != ds::StatusCode::kNetWorkError || err_code != grpc::StatusCode::UNAVAILABLE) { - return ds::Status(ds::StatusCode::kUnexpectedError, __LINE__, __FILE__); - } - } - return ds::Status::OK(); -} - /// Start the server /// \param argv /// \return Status object @@ -235,15 +125,8 @@ ds::Status StartServer(int argc, char **argv) { } int main(int argc, char **argv) { - ds::Status rc; - ds::CacheServer::Builder builder; - // This executable is not to be called directly, and should be invoked by cache_admin executable. 
- if (strcmp(argv[0], "-") == 0) { - rc = StopServer(argc, argv); - } else { - rc = StartServer(argc, argv); - } + ds::Status rc = StartServer(argc, argv); // Check result if (rc.IsError()) { auto errCode = rc.get_code(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc index fdec89a590..e38d138434 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -20,6 +20,7 @@ #include #endif #include +#include #include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/engine/cache/cache_client.h" @@ -326,5 +327,11 @@ Status ListSessionsRequest::PostReply() { return Status::OK(); } + +Status ServerStopRequest::PostReply() { + CHECK_FAIL_RETURN_UNEXPECTED(strcmp(reply_.result().data(), "OK") == 0, "Not the right response"); + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h index 43cd66f852..059dd26c4d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -394,6 +394,15 @@ class ToggleWriteModeRequest : public BaseRequest { ~ToggleWriteModeRequest() override = default; }; +class ServerStopRequest : public BaseRequest { + public: + friend class CacheServer; + explicit ServerStopRequest(int32_t qID) : BaseRequest(RequestType::kStopService) { + rq_.add_buf_data(std::to_string(qID)); + } + Status PostReply() override; +}; + class ConnectResetRequest : public BaseRequest { public: friend class CacheServer; diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc index 9bdac9f433..7445c65ddf 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc @@ -108,6 +108,9 @@ Status CacheServer::DoServiceStart() { try { comm_layer_ = std::make_shared(port_, shared_memory_sz_in_gb_); RETURN_IF_NOT_OK(comm_layer_->Run()); + // Bring up a thread to monitor the unix socket in case it is removed. + auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get()); + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f)); } catch (const std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); } @@ -154,6 +157,15 @@ Status CacheServer::DoServiceStop() { } ++it; } + // Also remove the path we use to generate ftok. + Path p(PortToUnixSocketPath(port_)); + (void)p.Remove(); + // Finally wake up cache_admin if it is waiting + for (int32_t qID : shutdown_qIDs_) { + SharedMessage msg(qID); + msg.RemoveResourcesOnExit(); + // Let msg go out of scope, which will destroy the queue. + } return rc; } @@ -374,6 +386,68 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) { return rc; } +Status CacheServer::BatchFetch(const std::shared_ptr &fbb, WritableSlice *out) { + RETURN_UNEXPECTED_IF_NULL(out); + int32_t numQ = GetNumGrpcWorkers(); + auto rng = GetRandomDevice(); + std::uniform_int_distribution distribution(0, numQ - 1); + int32_t qID = distribution(rng); + std::vector cache_rq_list; + auto p = flatbuffers::GetRoot(fbb->GetBufferPointer()); + const auto num_elements = p->rows()->size(); + auto connection_id = p->connection_id(); + cache_rq_list.reserve(num_elements); + int64_t data_offset = (num_elements + 1) * sizeof(int64_t); + auto *offset_array = reinterpret_cast(out->GetMutablePointer()); + offset_array[0] = data_offset; + for (auto i = 0; i < num_elements; ++i) { + auto data_locator = p->rows()->Get(i); + auto node_id = data_locator->node_id(); + size_t sz = data_locator->size(); + void *source_addr = 
reinterpret_cast(data_locator->addr()); + auto key = data_locator->key(); + // Please read the comment in CacheServer::BatchFetchRows where we allocate + // the buffer big enough so each thread (which we are going to dispatch) will + // not run into false sharing problem. We are going to round up sz to 4k. + auto sz_4k = round_up_4K(sz); + offset_array[i + 1] = offset_array[i] + sz_4k; + if (sz > 0) { + WritableSlice row_data(*out, offset_array[i], sz); + // Get a request and send to the proper worker (at some numa node) to do the fetch. + worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker(); + CacheServerRequest *cache_rq; + RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq)); + cache_rq_list.push_back(cache_rq); + // Set up all the necessarily field. + cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow; + cache_rq->st_ = CacheServerRequest::STATE::PROCESS; + cache_rq->rq_.set_connection_id(connection_id); + cache_rq->rq_.set_type(static_cast(cache_rq->type_)); + auto dest_addr = row_data.GetMutablePointer(); + flatbuffers::FlatBufferBuilder fb2; + FetchRowMsgBuilder bld(fb2); + bld.add_key(key); + bld.add_size(sz); + bld.add_source_addr(reinterpret_cast(source_addr)); + bld.add_dest_addr(reinterpret_cast(dest_addr)); + auto offset = bld.Finish(); + fb2.Finish(offset); + cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize()); + RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq)); + } + } + // Now wait for all of them to come back. + Status rc; + for (CacheServerRequest *rq : cache_rq_list) { + RETURN_IF_NOT_OK(rq->Wait()); + if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) { + rc = rq->rc_; + } + RETURN_IF_NOT_OK(ReturnRequestTag(rq)); + } + return rc; +} + Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { auto connection_id = rq->connection_id(); // Hold the shared lock to prevent the cache from being dropped. 
@@ -394,6 +468,9 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { } std::shared_ptr fbb = std::make_shared(); RETURN_IF_NOT_OK(cs->PreBatchFetch(connection_id, row_id, fbb)); + // Let go of the shared lock. We don't need to interact with the CacheService anymore. + // We shouldn't be holding any lock while we can wait for a long time for the rows to come back. + lck.Unlock(); auto locator = flatbuffers::GetRoot(fbb->GetBufferPointer()); int64_t mem_sz = sizeof(int64_t) * (sz + 1); for (auto i = 0; i < sz; ++i) { @@ -418,7 +495,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { void *q = nullptr; RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q)); WritableSlice dest(q, mem_sz); - Status rc = cs->BatchFetch(fbb, &dest); + Status rc = BatchFetch(fbb, &dest); if (rc.IsError()) { shared_pool->Deallocate(q); return rc; @@ -439,7 +516,7 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) { return Status(StatusCode::kOutOfMemory); } WritableSlice dest(mem.data(), mem_sz); - RETURN_IF_NOT_OK(cs->BatchFetch(fbb, &dest)); + RETURN_IF_NOT_OK(BatchFetch(fbb, &dest)); reply->set_result(std::move(mem)); } } @@ -721,7 +798,7 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) { } case BaseRequest::RequestType::kStopService: { // This command shutdowns everything. - cache_req->rc_ = GlobalShutdown(); + cache_req->rc_ = GlobalShutdown(cache_req); break; } case BaseRequest::RequestType::kHeartBeat: { @@ -914,7 +991,25 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) { return Status::OK(); } -Status CacheServer::GlobalShutdown() { +Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) { + auto *rq = &cache_req->rq_; + auto *reply = &cache_req->reply_; + if (!rq->buf_data().empty()) { + // cache_admin sends us a message qID and we will destroy the + // queue in our destructor and this will wake up cache_admin. + // But we don't want the cache_admin blindly just block itself. 
+ // So we will send back an ack before shutdown the comm layer. + try { + int32_t qID = std::stoi(rq->buf_data(0)); + shutdown_qIDs_.push_back(qID); + } catch (const std::exception &e) { + // ignore it. + } + } + reply->set_result("OK"); + Status2CacheReply(cache_req->rc_, reply); + cache_req->st_ = CacheServerRequest::STATE::FINISH; + cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req); // Let's shutdown in proper order. bool expected = false; if (global_shutdown_.compare_exchange_strong(expected, true)) { @@ -939,7 +1034,7 @@ Status CacheServer::GlobalShutdown() { return Status::OK(); } -worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) { +worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const { auto num_numa_nodes = GetNumaNodeCount(); MS_ASSERT(numa_id < num_numa_nodes); auto num_workers_per_node = GetNumWorkers() / num_numa_nodes; @@ -951,7 +1046,7 @@ worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) { return worker_id; } -worker_id_t CacheServer::GetRandomWorker() { +worker_id_t CacheServer::GetRandomWorker() const { std::mt19937 gen = GetRandomDevice(); std::uniform_int_distribution dist(0, num_workers_ - 1); return dist(gen); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h index 85407d44e1..2fa07ee4d5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -187,11 +187,11 @@ class CacheServer : public Service { /// \brief Assign a worker by a numa id /// \return worker id - worker_id_t GetWorkerByNumaId(numa_id_t node_id); + worker_id_t GetWorkerByNumaId(numa_id_t node_id) const; /// \brief Randomly pick a worker /// \return worker id - worker_id_t GetRandomWorker(); + worker_id_t GetRandomWorker() const; /// \brief Check if we bind threads to numa cores bool IsNumaAffinityOn() const { return numa_affinity_; } @@ -227,6 +227,7 @@ 
class CacheServer : public Service { std::shared_ptr hw_info_; std::map numa_tasks_; bool numa_affinity_; + std::vector shutdown_qIDs_; /// \brief Constructor /// \param spill_path Top directory for spilling buffers to. @@ -314,7 +315,7 @@ class CacheServer : public Service { /// \brief A proper shutdown of the server /// \return Status object - Status GlobalShutdown(); + Status GlobalShutdown(CacheServerRequest *); /// \brief Find keys that will be cache miss /// \return Status object @@ -330,6 +331,13 @@ class CacheServer : public Service { /// \brief Connect request by a pipeline Status ConnectReset(CacheRequest *rq); + + /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded + /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. + /// \param[in] v A vector of row id. + /// \param[out] out A contiguous memory buffer that holds the requested rows. + /// \return Status object + Status BatchFetch(const std::shared_ptr &fbb, WritableSlice *out); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc index 624a279ba3..ee6fac25be 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc @@ -209,6 +209,10 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) { Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector &v, const std::shared_ptr &fbb) { SharedLock rw(&rw_lock_); + if (st_ == CacheServiceState::kBuildPhase) { + // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. 
+ RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } std::vector> datalocator_v; datalocator_v.reserve(v.size()); for (auto row_id : v) { @@ -225,76 +229,6 @@ Status CacheService::PreBatchFetch(connection_id_type connection_id, const std:: return Status::OK(); } -Status CacheService::BatchFetch(const std::shared_ptr &fbb, WritableSlice *out) const { - RETURN_UNEXPECTED_IF_NULL(out); - SharedLock rw(&rw_lock_); - if (st_ == CacheServiceState::kBuildPhase) { - // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. - RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); - } - CacheServer &cs = CacheServer::GetInstance(); - int32_t numQ = cs.GetNumGrpcWorkers(); - auto rng = GetRandomDevice(); - std::uniform_int_distribution distribution(0, numQ - 1); - int32_t qID = distribution(rng); - std::vector cache_rq_list; - auto p = flatbuffers::GetRoot(fbb->GetBufferPointer()); - const auto num_elements = p->rows()->size(); - auto connection_id = p->connection_id(); - cache_rq_list.reserve(num_elements); - int64_t data_offset = (num_elements + 1) * sizeof(int64_t); - auto *offset_array = reinterpret_cast(out->GetMutablePointer()); - offset_array[0] = data_offset; - for (auto i = 0; i < num_elements; ++i) { - auto data_locator = p->rows()->Get(i); - auto node_id = data_locator->node_id(); - size_t sz = data_locator->size(); - void *source_addr = reinterpret_cast(data_locator->addr()); - auto key = data_locator->key(); - // Please read the comment in CacheServer::BatchFetchRows where we allocate - // the buffer big enough so each thread (which we are going to dispatch) will - // not run into false sharing problem. We are going to round up sz to 4k. - auto sz_4k = round_up_4K(sz); - offset_array[i + 1] = offset_array[i] + sz_4k; - if (sz > 0) { - WritableSlice row_data(*out, offset_array[i], sz); - // Get a request and send to the proper worker (at some numa node) to do the fetch. 
- worker_id_t worker_id = cs.IsNumaAffinityOn() ? cs.GetWorkerByNumaId(node_id) : cs.GetRandomWorker(); - CacheServerRequest *cache_rq; - RETURN_IF_NOT_OK(cs.GetFreeRequestTag(qID++ % numQ, &cache_rq)); - cache_rq_list.push_back(cache_rq); - // Set up all the necessarily field. - cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow; - cache_rq->st_ = CacheServerRequest::STATE::PROCESS; - cache_rq->rq_.set_connection_id(connection_id); - cache_rq->rq_.set_type(static_cast(cache_rq->type_)); - auto dest_addr = row_data.GetMutablePointer(); - flatbuffers::FlatBufferBuilder fb2; - FetchRowMsgBuilder bld(fb2); - bld.add_key(key); - bld.add_size(sz); - bld.add_source_addr(reinterpret_cast(source_addr)); - bld.add_dest_addr(reinterpret_cast(dest_addr)); - auto offset = bld.Finish(); - fb2.Finish(offset); - cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize()); - RETURN_IF_NOT_OK(cs.PushRequest(worker_id, cache_rq)); - } - } - // Now wait for all of them to come back. Let go of the shared lock. We shouldn't be holding - // any lock while we can wait for a long time. - rw.Unlock(); - Status rc; - for (CacheServerRequest *rq : cache_rq_list) { - RETURN_IF_NOT_OK(rq->Wait()); - if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) { - rc = rq->rc_; - } - RETURN_IF_NOT_OK(cs.ReturnRequestTag(rq)); - } - return rc; -} - Status CacheService::InternalFetchRow(const FetchRowMsg *p) { RETURN_UNEXPECTED_IF_NULL(p); SharedLock rw(&rw_lock_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h index 474cf526c2..d1b10f7c4d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h @@ -75,13 +75,6 @@ class CacheService : public Service { Status PreBatchFetch(connection_id_type connection_id, const std::vector &v, const std::shared_ptr &); - /// \brief Main function to fetch rows in batch. 
The output is a contiguous memory which will be decoded - /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. - /// \param[in] v A vector of row id. - /// \param[out] out A contiguous memory buffer that holds the requested rows. - /// \return Status object - Status BatchFetch(const std::shared_ptr &, WritableSlice *out) const; - /// \brief Getter function /// \return Spilling path Path GetSpillPath() const; diff --git a/mindspore/ccsrc/minddata/dataset/util/slice.h b/mindspore/ccsrc/minddata/dataset/util/slice.h index 058b822332..de37037c79 100644 --- a/mindspore/ccsrc/minddata/dataset/util/slice.h +++ b/mindspore/ccsrc/minddata/dataset/util/slice.h @@ -87,6 +87,7 @@ class WritableSlice : public ReadableSlice { public: friend class StorageContainer; friend class CacheService; + friend class CacheServer; /// \brief Default constructor WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} /// \brief This form of a constructor takes a pointer and its size. diff --git a/tests/ut/cpp/dataset/c_api_cache_test.cc b/tests/ut/cpp/dataset/c_api_cache_test.cc index 9b4eb6348e..04cf3a7a54 100644 --- a/tests/ut/cpp/dataset/c_api_cache_test.cc +++ b/tests/ut/cpp/dataset/c_api_cache_test.cc @@ -42,7 +42,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheCApiSamplerNull) { // Create an ImageFolder Dataset, this folder_path only has 2 images in it std::string folder_path = datasets_root_path_ + "/testImageNetData/train/"; std::shared_ptr ds = ImageFolder(folder_path, false, nullptr, {}, {}, some_cache); - EXPECT_EQ(ds, nullptr); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + // Now the parameter check for ImageFolderNode would fail and we would end up with a nullptr iter. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); } TEST_F(MindDataTestCacheOp, DISABLED_TestCacheImageFolderCApi) { diff --git a/tests/ut/python/cachetests/cachetest_args.sh b/tests/ut/python/cachetests/cachetest_args.sh index 7f6f3082f8..70ad5a18cf 100755 --- a/tests/ut/python/cachetests/cachetest_args.sh +++ b/tests/ut/python/cachetests/cachetest_args.sh @@ -121,7 +121,7 @@ HandleRcExit $? 0 1 # find a port that is occupied using netstat if [ -x "$(command -v netstat)" ]; then - port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1) + port=$(netstat -ntp | grep -v '::' | awk '{print $4}' | grep -E '^[[:digit:]]+' | awk -F: '{print $2}' | sort -n | tail -n 1) if [ ${port} -gt 1025 ]; then # start cache server with occupied port cmd="${CACHE_ADMIN} --start -p ${port}" @@ -171,7 +171,12 @@ HandleRcExit $? 0 0 cmd="${CACHE_ADMIN} --start -w illegal" CacheAdminCmd "${cmd}" 1 HandleRcExit $? 0 0 -cmd="${CACHE_ADMIN} --start -w 101" +num_cpu=$(grep -c processor /proc/cpuinfo) +if [ $num_cpu -lt 100 ]; then + cmd="${CACHE_ADMIN} --start -w 101" +else + cmd="${CACHE_ADMIN} --start -w ${num_cpu}+1" +fi CacheAdminCmd "${cmd}" 1 HandleRcExit $? 0 0 cmd="${CACHE_ADMIN} --start -w 9999999" diff --git a/tests/ut/python/test_server_stop_testcase.sh b/tests/ut/python/test_server_stop_testcase.sh deleted file mode 100755 index 7187e00908..0000000000 --- a/tests/ut/python/test_server_stop_testcase.sh +++ /dev/null @@ -1,10 +0,0 @@ -~/cache/cache_admin --start -session_id=$(~/cache/cache_admin -g | awk '{print $NF}') -export SESSION_ID=${session_id} -pytest dataset/test_cache_nomap.py::test_cache_nomap_server_stop & -pid=("$!") - -sleep 2 -~/cache/cache_admin --stop -sleep 1 -wait ${pid}