Fix hang on RequestFreeTag

pull/9539/head
Jesse Lee 4 years ago committed by Lixia Chen
parent b793c7b291
commit 0c8dfc68ec

@@ -99,6 +99,7 @@ ds::Status StartServer(int argc, char **argv) {
std::cout << "\nRecommendation:\nSince the server is detached into its own daemon process, monitor the server "
"logs (under "
<< ds::DefaultLogDir() << ") for any issues that may happen after startup\n";
MS_LOG(INFO) << "Cache server has started successfully and is listening on port " << port << std::endl;
signal(SIGCHLD, SIG_IGN); // Ignore SIGCHLD so exited children are reaped automatically and never become zombies.
return ds::Status::OK();
} else {

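For context on the signal(SIGCHLD, SIG_IGN) line added above: the cache server detaches into its own daemon process and forks children, and without reaping them every exited child would linger as a zombie. A minimal POSIX sketch of the effect (a standalone illustration, not part of this commit):

#include <csignal>
#include <cstdio>
#include <sys/types.h>
#include <unistd.h>

int main() {
  // SIG_IGN tells the kernel to reap terminated children automatically,
  // so a long-lived daemon that never calls wait() leaves no zombies.
  signal(SIGCHLD, SIG_IGN);
  pid_t pid = fork();
  if (pid == 0) {
    _exit(0);  // child exits right away
  }
  sleep(1);  // give the child time to exit; it will not linger as a zombie
  printf("child %d was reaped automatically\n", pid);
  return 0;
}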
@@ -362,7 +362,7 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
CacheService *cs = GetService(connection_id);
auto *base = SharedMemoryBaseAddr();
// Ensure we got at least 3 pieces of data coming in
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() >= 3, "Incomplete data");
// First piece of data is the cookie and is required
auto &cookie = rq->buf_data(0);
// Second piece of data is the address where we can find the serialized data
@@ -400,11 +400,10 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuil
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
auto p = flatbuffers::GetRoot<BatchDataLocatorMsg>(fbb->GetBufferPointer());
const auto num_elements = p->rows()->size();
auto connection_id = p->connection_id();
cache_rq_list.reserve(num_elements);
BatchWait batch_wait = BatchWait(num_elements);
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
@@ -425,7 +424,6 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuil
worker_id_t worker_id = IsNumaAffinityOn() ? GetWorkerByNumaId(node_id) : GetRandomWorker();
CacheServerRequest *cache_rq;
RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Set up all the necessary fields.
cache_rq->type_ = BaseRequest::RequestType::kInternalFetchRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
@@ -441,19 +439,17 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuil
auto offset = bld.Finish();
fb2.Finish(offset);
cache_rq->rq_.add_buf_data(fb2.GetBufferPointer(), fb2.GetSize());
cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(&batch_wait)));
RETURN_IF_NOT_OK(PushRequest(worker_id, cache_rq));
} else {
// Nothing to fetch but we still need to post something back into the wait area.
RETURN_IF_NOT_OK(batch_wait.Set(Status::OK()));
}
}
// Now wait for all of them to come back.
Status rc;
for (CacheServerRequest *rq : cache_rq_list) {
RETURN_IF_NOT_OK(rq->Wait());
if (rq->rc_.IsError() && !rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = rq->rc_;
}
RETURN_IF_NOT_OK(ReturnRequestTag(rq));
}
return rc;
RETURN_IF_NOT_OK(batch_wait.Wait());
// Return the result
return batch_wait.GetRc();
}
Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
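The new BatchWait plumbing above works by smuggling a pointer through the request's string buffers: the grpc-side coordinator appends std::to_string(reinterpret_cast<int64_t>(&batch_wait)) to each internal request, and the worker later parses that string back into an address before posting its return code. A minimal standalone sketch of the round trip (BatchWaitLike is an illustrative stand-in, not part of the commit):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

struct BatchWaitLike {  // stand-in for the real BatchWait
  int value = 42;
};

int main() {
  BatchWaitLike bw;
  // Coordinator side: encode the object's address as a decimal string,
  // mirroring rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(&batch_wait))).
  std::string buf = std::to_string(reinterpret_cast<int64_t>(&bw));
  // Worker side: parse the string back into an address and cast to a pointer,
  // mirroring what ProcessRequest does before calling bw->Set(...).
  int64_t addr = std::strtoll(buf.data(), nullptr, 10);
  auto *p = reinterpret_cast<BatchWaitLike *>(addr);
  std::cout << p->value << std::endl;  // prints 42
  return 0;
}

This is safe here only because batch_wait lives on the coordinator's stack and the blocking batch_wait.Wait() guarantees it outlives every worker that dereferences the address.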
@@ -727,7 +723,6 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
try {
auto &cookie = rq->buf_data(0);
auto connection_id = rq->connection_id();
@@ -738,7 +733,7 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
cache_rq_list.reserve(num_elem);
BatchWait batch_wait = BatchWait(num_elem);
// Get a set of free requests and push them into the queues.
for (auto i = 0; i < num_elem; ++i) {
auto start = reinterpret_cast<int64_t>(p);
@@ -749,7 +744,6 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
}
CacheServerRequest *cache_rq;
RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Fill in details.
cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
@@ -760,25 +754,20 @@ Status CacheServer::BatchCacheRows(CacheRequest *rq) {
cache_rq->rq_.add_buf_data(cookie);
cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p - start)));
cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(&batch_wait)));
RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
}
// Now wait for all of them to come back.
Status rc;
for (CacheServerRequest *cache_rq : cache_rq_list) {
RETURN_IF_NOT_OK(cache_rq->Wait());
if (cache_rq->rc_.IsError() && !cache_rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = cache_rq->rc_;
}
RETURN_IF_NOT_OK(ReturnRequestTag(cache_rq));
}
return rc;
RETURN_IF_NOT_OK(batch_wait.Wait());
// Return the result
return batch_wait.GetRc();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
Status CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
bool internal_request = false;
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
@@ -792,6 +781,17 @@ void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
if (BitTest(flag, kDataIsInSharedMemory)) {
cache_req->rc_ = FastCacheRow(&rq, &reply);
internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
if (internal_request) {
// This is an internal request and is not tied to rpc, but we still need to post the result
// because a thread is waiting on the completion of this request.
try {
int64_t addr = strtoll(rq.buf_data(3).data(), nullptr, 10);
auto *bw = reinterpret_cast<BatchWait *>(addr);
RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
} else {
cache_req->rc_ = CacheRow(&rq, &reply);
}
@@ -815,6 +815,15 @@ void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
// This is an internal request and is not tied to rpc, but we still need to post the result
// because a thread is waiting on the completion of this request.
try {
int64_t addr = strtoll(rq.buf_data(1).data(), nullptr, 10);
auto *bw = reinterpret_cast<BatchWait *>(addr);
RETURN_IF_NOT_OK(bw->Set(std::move(cache_req->rc_)));
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
break;
}
@@ -912,10 +921,10 @@ void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
if (!internal_request) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
} else {
// This is an internal request and is not tied to rpc, but we still need to post
// because a thread is waiting on the completion of this request.
cache_req->wp_.Set();
// We can free up the request now.
RETURN_IF_NOT_OK(ReturnRequestTag(cache_req));
}
return Status::OK();
}
/// \brief This is the main loop the cache server thread(s) are running.
@@ -929,7 +938,7 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
while (!global_shutdown_) {
CacheServerRequest *cache_req = nullptr;
RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
ProcessRequest(cache_req);
RETURN_IF_NOT_OK(ProcessRequest(cache_req));
}
return Status::OK();
}

@@ -204,7 +204,7 @@ class CacheServer : public Service {
/// \brief How a request is handled.
/// \note that it can be processed immediately by a grpc thread or routed to a server thread
/// which is pinned to some numa node core.
void ProcessRequest(CacheServerRequest *cache_req);
Status ProcessRequest(CacheServerRequest *cache_req);
void GlobalShutdown();
@@ -351,6 +351,52 @@ class CacheServer : public Service {
/// \brief Connect request by a pipeline
Status ConnectReset(CacheRequest *rq);
/// \brief This is an internal structure used by batch processing.
/// This is how it works internally. For a batch fetch/cache, the grpc thread
/// will intercept the request and break it down into multiple internal requests
/// spread over all the server workers. But each internal request consumes
/// one free tag, and we may run out of free tags if they don't come back promptly.
/// So we let the server thread return the free tag immediately and post the
/// return code into the following structure instead. The grpc thread must wait
/// until all the return codes come back.
class BatchWait {
public:
explicit BatchWait(int n) : expected_(n), num_back_(0) { rc_lists_.reserve(expected_); }
Status Set(Status rc) {
std::unique_lock<std::mutex> lck(mux_);
CHECK_FAIL_RETURN_UNEXPECTED(expected_ > num_back_, "Programming error");
rc_lists_.push_back(std::move(rc));
++num_back_;
if (num_back_ == expected_) {
wp_.Set();
}
return Status::OK();
}
Status Wait() { return wp_.Wait(); }
Status GetRc() {
Status rc;
for (auto &cache_rc : rc_lists_) {
if (cache_rc.IsError() && !cache_rc.IsInterrupted() && rc.IsOk()) {
rc = cache_rc;
}
}
return rc;
}
private:
std::mutex mux_;
WaitPost wp_;
int64_t expected_;
int64_t num_back_;
std::vector<Status> rc_lists_;
};
/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
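To see the synchronization pattern BatchWait implements in isolation, here is a minimal standalone analogue built only on std::mutex and std::condition_variable (BatchGate and all names in it are illustrative, not part of the MindSpore code): each worker posts one return code, and the coordinator blocks until all of them have arrived.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Illustrative stand-in for BatchWait: workers post one result each and the
// coordinator blocks until all expected results have come back.
class BatchGate {
 public:
  explicit BatchGate(int n) : expected_(n), num_back_(0) {}
  void Set(int rc) {
    std::unique_lock<std::mutex> lck(mux_);
    rcs_.push_back(rc);
    if (++num_back_ == expected_) cv_.notify_one();
  }
  void Wait() {
    std::unique_lock<std::mutex> lck(mux_);
    cv_.wait(lck, [this] { return num_back_ == expected_; });
  }
  // First non-zero rc wins, mirroring BatchWait::GetRc's error aggregation.
  int GetRc() {
    for (int rc : rcs_) {
      if (rc != 0) return rc;
    }
    return 0;
  }
 private:
  std::mutex mux_;
  std::condition_variable cv_;
  int expected_;
  int num_back_;
  std::vector<int> rcs_;
};

int main() {
  constexpr int kWorkers = 4;
  BatchGate gate(kWorkers);
  std::vector<std::thread> pool;
  for (int i = 0; i < kWorkers; ++i) {
    pool.emplace_back([&gate] { gate.Set(0); });  // each worker posts its rc
  }
  gate.Wait();                             // coordinator blocks until all rcs are back
  std::cout << gate.GetRc() << std::endl;  // 0 == everything succeeded
  for (auto &t : pool) t.join();
  return 0;
}

The real BatchWait uses the project's WaitPost primitive instead of a bare condition variable, but the count-and-post logic is the same.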

@@ -796,43 +796,6 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCApiCacheShareFailure1) {
std::shared_ptr<DatasetCache> some_cache = CreateDatasetCache(env_session, 0, true);
EXPECT_NE(some_cache, nullptr);
// Create an ImageFolder Dataset, this folder_path only has 2 images in it
std::string folder_path = datasets_root_path_ + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, RandomSampler(), {}, {}, some_cache);
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Dataset> ds2 = ImageFolder(folder_path, true, SequentialSampler(), {}, {}, some_cache);
EXPECT_NE(ds2, nullptr);
// Create and launch the Execution Tree for ds1
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
EXPECT_NE(iter1, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter1->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
iter1->GetNextRow(&row);
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline
iter1->Stop();
// Re-using the cache for the second pipeline should fail
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
EXPECT_EQ(iter2, nullptr);
}
TEST_F(MindDataTestCacheOp, DISABLED_TestCApiCacheShareFailure2) {
session_id_type env_session;
Status s = GetSessionFromEnv(&env_session);
EXPECT_EQ(s, Status::OK());
std::shared_ptr<DatasetCache> some_cache = CreateDatasetCache(env_session, 0, true);
EXPECT_NE(some_cache, nullptr);
// Create an ImageFolder Dataset, this folder_path only has 2 images in it
std::string folder_path = datasets_root_path_ + "/testImageNetData/train/";
std::shared_ptr<Dataset> ds1 = ImageFolder(folder_path, true, RandomSampler(), {}, {}, some_cache);

@@ -1820,6 +1820,43 @@ def test_cache_map_cifar2():
logger.info("test_cache_map_cifar2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar3():
"""
Test mappable cifar10 leaf with the cache op later in the tree above the map(resize)
In this case, we set a extra-small size for cache (size=1) and there are 10000 rows in the dataset.
cache
|
Map(resize)
|
Cifar100
"""
logger.info("Test cache map cifar3")
if "SESSION_ID" in os.environ:
session_id = int(os.environ['SESSION_ID'])
else:
raise RuntimeError("Testcase requires SESSION_ID environment variable")
some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
resize_op = c_vision.Resize((224, 224))
ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
num_epoch = 2
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
epoch_count = 0
for _ in range(num_epoch):
assert sum([1 for _ in iter1]) == 10000
epoch_count += 1
assert epoch_count == num_epoch
logger.info("test_cache_map_cifar3 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_voc1():
"""
