|
|
|
@ -25,6 +25,7 @@ namespace detail {
|
|
|
|
|
namespace {
|
|
|
|
|
const int kNumHandleSendThreads = 20;
|
|
|
|
|
const int kNumHandleGetThreads = 20;
|
|
|
|
|
const int kNumHandlePrefetchThreads = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
enum CallStatus { PROCESS = 0, FINISH };
|
|
|
|
|
|
|
|
|
@ -180,8 +181,9 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
|
request_.reset(new VariableResponse(scope, dev_ctx_, true));
|
|
|
|
|
}
|
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
|
|
|
|
|
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
|
|
|
|
|
cq_, cq_, this);
|
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual ~RequestPrefetch() {}
|
|
|
|
@ -190,7 +192,6 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
|
|
|
|
|
|
virtual void Process() {
|
|
|
|
|
// prefetch process...
|
|
|
|
|
::grpc::ByteBuffer reply;
|
|
|
|
|
|
|
|
|
|
std::string var_name = request_->OutVarname();
|
|
|
|
|
VLOG(3) << "RequestPrefetch " << var_name;
|
|
|
|
@ -200,15 +201,16 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
|
InitializeVariable(var, var_desc->GetType());
|
|
|
|
|
executor_->RunPreparedContext(prefetch_ctx_, scope_);
|
|
|
|
|
|
|
|
|
|
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
|
|
|
|
|
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
|
|
|
|
|
|
|
|
|
|
responder_.Finish(reply, ::grpc::Status::OK,
|
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
|
|
|
|
|
status_ = FINISH;
|
|
|
|
|
responder_.Finish(reply_, ::grpc::Status::OK,
|
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::shared_ptr<VariableResponse> request_;
|
|
|
|
|
::grpc::ByteBuffer reply_;
|
|
|
|
|
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
|
|
|
|
|
framework::Scope* scope_;
|
|
|
|
|
framework::Executor* executor_;
|
|
|
|
@ -262,6 +264,9 @@ void AsyncGRPCServer::RunSyncUpdate() {
|
|
|
|
|
for (int i = 0; i < kGetReqsBufSize; ++i) {
|
|
|
|
|
TryToRegisterNewGetOne(i);
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
|
|
|
|
|
TryToRegisterNewPrefetchOne(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < kNumHandleSendThreads; ++i) {
|
|
|
|
|
t_sends_.emplace_back(
|
|
|
|
@ -273,12 +278,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
|
|
|
|
|
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
|
|
|
|
|
cq_get_.get(), "cq_get", get_register)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(wuyi): Run these "HandleRequest" in thread pool
|
|
|
|
|
t_prefetch_.reset(new std::thread(
|
|
|
|
|
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
|
|
|
|
|
t_prefetchs_.emplace_back(new std::thread(
|
|
|
|
|
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
|
|
|
|
|
"cq_prefetch", prefetch_register)));
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
{
|
|
|
|
|
std::lock_guard<std::mutex> lock(this->mutex_ready_);
|
|
|
|
|
ready_ = 1;
|
|
|
|
@ -292,7 +296,9 @@ void AsyncGRPCServer::RunSyncUpdate() {
|
|
|
|
|
for (int i = 0; i < kNumHandleGetThreads; ++i) {
|
|
|
|
|
t_gets_[i]->join();
|
|
|
|
|
}
|
|
|
|
|
t_prefetch_->join();
|
|
|
|
|
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
|
|
|
|
|
t_prefetchs_[i]->join();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::ShutdownQueue() {
|
|
|
|
@ -342,6 +348,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
|
|
|
|
|
RequestPrefetch* prefetch = new RequestPrefetch(
|
|
|
|
|
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
|
|
|
|
|
program_, prefetch_ctx_.get(), req_id);
|
|
|
|
|
prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
|
|
|
|
|
}
|
|
|
|
@ -376,8 +383,8 @@ void AsyncGRPCServer::HandleRequest(
|
|
|
|
|
base = get_reqs_[req_id];
|
|
|
|
|
} else if (cq_name == "cq_send") {
|
|
|
|
|
base = send_reqs_[req_id];
|
|
|
|
|
} else {
|
|
|
|
|
CHECK(false);
|
|
|
|
|
} else if (cq_name == "cq_prefetch") {
|
|
|
|
|
base = prefetch_reqs_[req_id];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// reference:
|
|
|
|
|