diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc
index d67ba30efc..93f92e6b97 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc
@@ -99,6 +99,7 @@ ds::Status StartServer(int argc, char **argv) {
     std::cout << "\nRecommendation:\nSince the server is detached into its own daemon process, monitor the server "
                  "logs (under "
              << ds::DefaultLogDir() << ") for any issues that may happen after startup\n";
+    MS_LOG(INFO) << "Cache server has started successfully and is listening on port " << port << std::endl;
     signal(SIGCHLD, SIG_IGN);  // ignore sig child signal.
     return ds::Status::OK();
   } else {
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
index c4c57132b6..29bf177cad 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
@@ -362,7 +362,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
   CacheService *cs = GetService(connection_id);
   auto *base = SharedMemoryBaseAddr();
-  // Ensure we got 3 pieces of data coming in
-  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
+  // Ensure we got at least 3 pieces of data coming in
+  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= 3, "Incomplete data");
   // First piece of data is the cookie and is required
   auto &cookie = rq->buf_data(0);
   // Second piece of data is the address where we can find the serialized data
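A note on the relaxed size check above: internal kInternalCacheRow requests now append a fourth piece of buf_data, the stringified address of the coordinating BatchWait object (see the cache_server.h hunk below), so FastCacheRow must accept three or more pieces rather than exactly three. A minimal standalone sketch, assuming buf_data can be modeled as a vector of strings (names here are illustrative, not the MindSpore API):

    // Sketch only (not MindSpore code): why the check became >= 3.
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // buf_data modeled as strings: cookie, offset, size, [BatchWait address]
      std::vector<std::string> buf_data = {"cookie", "1024", "4096"};
      bool internal_request = true;
      if (internal_request) {
        int64_t fake_addr = 0xdeadbeef;  // stands in for &batch_wait
        buf_data.push_back(std::to_string(fake_addr));
      }
      // The old check (== 3) would reject the internal form; >= 3 accepts both.
      if (buf_data.size() >= 3) {
        std::cout << "accepted with " << buf_data.size() << " pieces\n";
      }
      return 0;
    }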
@@ -400,11 +400,10 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
   auto rng = GetRandomDevice();
   std::uniform_int_distribution<int32_t> distribution(0, numQ - 1);
   int32_t qID = distribution(rng);
-  std::vector<CacheServerRequest *> cache_rq_list;
   auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
   const auto num_elements = p->rows()->size();
   auto connection_id = p->connection_id();
-  cache_rq_list.reserve(num_elements);
+  BatchWait batch_wait = BatchWait(num_elements);
   int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
   auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
   offset_array[0] = data_offset;
@@ -425,7 +424,6 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
       }
       CacheServerRequest *cache_rq;
       RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
-      cache_rq_list.push_back(cache_rq);
       // Fill in details.
       cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
       cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
@@ -441,19 +439,17 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out) {
       cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
+      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(&batch_wait)));
       RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
+    } else {
+      // Nothing to fetch but we still need to post something back into the wait area.
+      RETURN_IF_NOT_OK(batch_wait.Set(Status::OK()));
     }
   }
   // Now wait for all of them to come back.
-  Status rc;
-  for (CacheServerRequest *rq : cache_rq_list) {
-    RETURN_IF_NOT_OK(rq->Wait());
-    if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
-      rc = rq->rc_;
-    }
-    RETURN_IF_NOT_OK(ReturnRequestTag(rq));
-  }
-  return rc;
+  RETURN_IF_NOT_OK(batch_wait.Wait());
+  // Return the result
+  return batch_wait.GetRc();
 }
 
 Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
@@ -727,7 +723,6 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
   auto rng = GetRandomDevice();
   std::uniform_int_distribution<int32_t> distribution(0, numQ - 1);
   int32_t qID = distribution(rng);
-  std::vector<CacheServerRequest *> cache_rq_list;
   try {
     auto &cookie = rq->buf_data(0);
     auto connection_id = rq->connection_id();
@@ -738,7 +733,7 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
     offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
     auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
     num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
-    cache_rq_list.reserve(num_elem);
+    BatchWait batch_wait = BatchWait(num_elem);
     // Get a set of free requests and push them into the queues.
     for (auto i = 0; i < num_elem; ++i) {
       auto start = reinterpret_cast<int64_t>(p);
@@ -749,7 +744,6 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
       }
       CacheServerRequest *cache_rq;
       RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
-      cache_rq_list.push_back(cache_rq);
       // Fill in details.
       cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
       cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
@@ -760,25 +754,20 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
       cache_rq->rq_.add_buf_data(cookie);
       cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
       cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p) - start));
+      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(&batch_wait)));
       RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
     }
     // Now wait for all of them to come back.
-    Status rc;
-    for (CacheServerRequest *cache_rq : cache_rq_list) {
-      RETURN_IF_NOT_OK(cache_rq->Wait());
-      if (cache_rq->rc_.IsError() && !cache_rq->rc_.IsInterrupted() && rc.IsOk()) {
-        rc = cache_rq->rc_;
-      }
-      RETURN_IF_NOT_OK(ReturnRequestTag(cache_rq));
-    }
-    return rc;
+    RETURN_IF_NOT_OK(batch_wait.Wait());
+    // Return the result
+    return batch_wait.GetRc();
   } catch (const std::exception &e) {
     RETURN_STATUS_UNEXPECTED(e.what());
   }
   return Status::OK();
 }
 
-void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
+Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
   bool internal_request = false;
   auto &rq = cache_req->rq_;
   auto &reply = cache_req->reply_;
@@ -792,6 +781,17 @@ void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
       if (BitTest(flag, kDataIsInSharedMemory)) {
         cache_req->rc_ = FastCacheRow(&rq, &reply);
         internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
+        if (internal_request) {
+          // This is an internal request and is not tied to rpc. But we need to post because
+          // there is a thread waiting on the completion of this request.
+          try {
+            int64_t addr = strtol(rq.buf_data(3).data(), nullptr, 10);
+            auto *bw = reinterpret_cast<BatchWait *>(addr);
+            RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
+          } catch (const std::exception &e) {
+            RETURN_STATUS_UNEXPECTED(e.what());
+          }
+        }
       } else {
         cache_req->rc_ = CacheRow(&rq, &reply);
       }
@@ -815,6 +815,15 @@ void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
         cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
       } else {
         cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
+        // This is an internal request and is not tied to rpc. But we need to post because
+        // there is a thread waiting on the completion of this request.
+        try {
+          int64_t addr = strtol(rq.buf_data(1).data(), nullptr, 10);
+          auto *bw = reinterpret_cast<BatchWait *>(addr);
+          RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
+        } catch (const std::exception &e) {
+          RETURN_STATUS_UNEXPECTED(e.what());
+        }
       }
       break;
     }
@@ -912,10 +921,10 @@ void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
   if (!internal_request) {
     cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
   } else {
-    // This is an internal request and is not tied to rpc. But need to post because there
-    // is a thread waiting on the completion of this request.
-    cache_req->wp_.Set();
+    // We can free up the request now.
+    RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
   }
+  return Status::OK();
 }
 
 /// \brief This is the main loop the cache server thread(s) are running.
@@ -929,7 +938,7 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
   while (!global_shutdown_) {
     CacheServerRequest *cache_req = nullptr;
     RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
-    ProcessRequest(cache_req);
+    RETURN_IF_NOT_OK(ProcessRequest(cache_req));
   }
   return Status::OK();
 }
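The hand-off between the grpc thread and the workers relies on smuggling a raw pointer through the request payload: the producer serializes the stack address of batch_wait with std::to_string(reinterpret_cast<int64_t>(&batch_wait)), and ProcessRequest parses it back with strtol and a reinterpret_cast. A standalone round-trip sketch (illustrative only, not MindSpore code); it uses strtoll since strtol returns a long, which is only 64 bits wide on LP64 platforms such as Linux:

    #include <cassert>
    #include <cstdint>
    #include <cstdlib>
    #include <string>

    struct BatchWait {};  // stand-in for the real class

    int main() {
      BatchWait bw;
      // Producer side: serialize the address into a decimal string, as the patch does.
      std::string buf = std::to_string(reinterpret_cast<int64_t>(&bw));
      // Consumer side: parse it back and reinterpret it as a pointer.
      int64_t addr = strtoll(buf.c_str(), nullptr, 10);
      auto *recovered = reinterpret_cast<BatchWait *>(addr);
      assert(recovered == &bw);
      return 0;
    }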
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
index b6ac0d0f8e..0c25c51260 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
@@ -204,7 +204,7 @@ class CacheServer : public Service {
   /// \brief How a request is handled.
   /// \note that it can be process immediately by a grpc thread or routed to a server thread
   /// which is pinned to some numa node core.
-  void ProcessRequest(CacheServerRequest *cache_req);
+  Status ProcessRequest(CacheServerRequest *cache_req);
 
   void GlobalShutdown();
 
@@ -351,6 +351,52 @@ class CacheServer : public Service {
   /// \brief Connect request by a pipeline
   Status ConnectReset(CacheRequest *rq);
 
+  /// \brief This is an internal structure used by batch processing.
+  /// Here is how it works internally. For batch fetch/cache, the grpc thread
+  /// will intercept the request and break it down into multiple internal requests
+  /// spread over all the server workers. But each internal request consumes
+  /// one free tag, and we may run out of free tags if they don't return promptly.
+  /// So we let the server thread return the free tag immediately and post the
+  /// return code in the structure below instead. The grpc thread must wait until
+  /// all the rc come back.
+  class BatchWait {
+   public:
+    explicit BatchWait(int n) : expected_(n), num_back_(0) { rc_lists_.reserve(expected_); }
+
+    Status Set(Status rc) {
+      std::unique_lock<std::mutex> lck(mux_);
+      CHECK_FAIL_RETURN_UNEXPECTED(expected_ > num_back_, "Programming error");
+      rc_lists_.push_back(std::move(rc));
+      ++num_back_;
+      if (num_back_ == expected_) {
+        wp_.Set();
+      }
+      return Status::OK();
+    }
+
+    Status Wait() { return wp_.Wait(); }
+
+    Status GetRc() {
+      Status rc;
+      for (auto &cache_rc : rc_lists_) {
+        if (cache_rc.IsError() && !cache_rc.IsInterrupted() && rc.IsOk()) {
+          rc = cache_rc;
+        }
+      }
+      return rc;
+    }
+
+   private:
+    std::mutex mux_;
+    WaitPost wp_;
+    int64_t expected_;
+    int64_t num_back_;
+    std::vector<Status> rc_lists_;
+  };
+
   /// \brief Internal function to do row batch fetch
   /// \param rq Request
   /// \param reply Reply
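For readers who have not seen WaitPost (a MindSpore-internal post/wait primitive), BatchWait above is a fan-out/fan-in barrier: N internal requests each post one Status, and the grpc thread blocks until all N have arrived, then surfaces the first non-interrupted error. A standalone sketch of the same idea using std::condition_variable in place of WaitPost (illustrative, not the MindSpore API):

    #include <condition_variable>
    #include <cstdint>
    #include <mutex>
    #include <thread>
    #include <vector>

    enum class Code { kOk, kError, kInterrupted };
    struct Status {
      Code code = Code::kOk;
      bool IsError() const { return code != Code::kOk; }
      bool IsInterrupted() const { return code == Code::kInterrupted; }
      bool IsOk() const { return code == Code::kOk; }
    };

    class BatchWaitSketch {
     public:
      explicit BatchWaitSketch(int64_t n) : expected_(n) { rc_list_.reserve(n); }

      void Set(Status rc) {
        std::unique_lock<std::mutex> lck(mux_);
        rc_list_.push_back(rc);
        if (++num_back_ == expected_) cv_.notify_all();
      }

      Status Wait() {
        std::unique_lock<std::mutex> lck(mux_);
        cv_.wait(lck, [this] { return num_back_ == expected_; });
        Status rc;
        for (const auto &r : rc_list_) {
          // Keep the first real (non-interrupted) error, like GetRc() above.
          if (r.IsError() && !r.IsInterrupted() && rc.IsOk()) rc = r;
        }
        return rc;
      }

     private:
      std::mutex mux_;
      std::condition_variable cv_;
      int64_t expected_;
      int64_t num_back_ = 0;
      std::vector<Status> rc_list_;
    };

    int main() {
      constexpr int kWorkers = 4;
      BatchWaitSketch bw(kWorkers);
      std::vector<std::thread> pool;
      for (int i = 0; i < kWorkers; ++i) {
        pool.emplace_back([&bw] { bw.Set(Status{}); });  // each worker posts OK
      }
      Status rc = bw.Wait();  // aggregated result
      for (auto &t : pool) t.join();
      return rc.IsOk() ? 0 : 1;
    }

Note that batch_wait lives on the stack of BatchFetch/BatchCacheRows, so the design only holds if every dispatched request posts exactly once; this is why the sz == 0 branch in BatchFetch posts Status::OK() even though it dispatches nothing. A missed post would leave the waiter blocked forever, and a post after the caller returned would write through a dangling pointer.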
diff --git a/tests/ut/cpp/dataset/c_api_cache_test.cc b/tests/ut/cpp/dataset/c_api_cache_test.cc
index b2b18c1037..a769c048f7 100644
--- a/tests/ut/cpp/dataset/c_api_cache_test.cc
+++ b/tests/ut/cpp/dataset/c_api_cache_test.cc
@@ -796,43 +796,6 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCApiCacheShareFailure1) {
   std::shared_ptr<DatasetCache> some_cache = CreateDatasetCache(env_session, 0, true);
   EXPECT_NE(some_cache, nullptr);
 
-  // Create an ImageFolder Dataset, this folder_path only has 2 images in it
-  std::string folder_path = datasets_root_path_ + "/testImageNetData/train/";
-  std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, RandomSampler(), {}, {}, some_cache);
-  EXPECT_NE(ds1, nullptr);
-  std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, true, SequentialSampler(), {}, {}, some_cache);
-  EXPECT_NE(ds2, nullptr);
-
-  // Create and launch the Execution Tree for ds1
-  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
-  EXPECT_NE(iter1, nullptr);
-  // Iterate the dataset and get each row
-  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
-  iter1->GetNextRow(&row);
-
-  uint64_t i = 0;
-  while (row.size() != 0) {
-    i++;
-    auto image = row["image"];
-    iter1->GetNextRow(&row);
-  }
-  EXPECT_EQ(i, 2);
-  // Manually terminate the pipeline
-  iter1->Stop();
-
-  // Re-use a cache for the second pipeline would fail
-  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
-  EXPECT_EQ(iter2, nullptr);
-}
-
-TEST_F(MindDataTestCacheOp, DISABLED_TestCApiCacheShareFailure2) {
-  session_id_type env_session;
-  Status s = GetSessionFromEnv(&env_session);
-  EXPECT_EQ(s, Status::OK());
-
-  std::shared_ptr<DatasetCache> some_cache = CreateDatasetCache(env_session, 0, true);
-  EXPECT_NE(some_cache, nullptr);
-
   // Create an ImageFolder Dataset, this folder_path only has 2 images in it
   std::string folder_path = datasets_root_path_ + "/testImageNetData/train/";
   std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, RandomSampler(), {}, {}, some_cache);
diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py
index e6d7d023bc..b6d69210f4 100644
--- a/tests/ut/python/dataset/test_cache_map.py
+++ b/tests/ut/python/dataset/test_cache_map.py
@@ -1820,6 +1820,43 @@ def test_cache_map_cifar2():
     logger.info("test_cache_map_cifar2 Ended.\n")
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_cifar3():
+    """
+    Test mappable cifar10 leaf with the cache op later in the tree above the map(resize)
+    In this case, we set an extra-small size for the cache (size=1) and there are 10000 rows in the dataset.
+
+       cache
+         |
+     Map(resize)
+         |
+      Cifar10
+    """
+
+    logger.info("Test cache map cifar3")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
+
+    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
+    resize_op = c_vision.Resize((224, 224))
+    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
+
+    num_epoch = 2
+    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
+
+    epoch_count = 0
+    for _ in range(num_epoch):
+        assert sum([1 for _ in iter1]) == 10000
+        epoch_count += 1
+    assert epoch_count == num_epoch
+
+    logger.info("test_cache_map_cifar3 Ended.\n")
+
+
 @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
 def test_cache_map_voc1():
     """