|
|
@ -66,11 +66,11 @@ class RequestSend final : public RequestBase {
|
|
|
|
explicit RequestSend(GrpcService::AsyncService* service,
|
|
|
|
explicit RequestSend(GrpcService::AsyncService* service,
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
framework::Scope* scope, ReceivedQueue* queue,
|
|
|
|
framework::Scope* scope, ReceivedQueue* queue,
|
|
|
|
const platform::DeviceContext* dev_ctx, int i)
|
|
|
|
const platform::DeviceContext* dev_ctx, int req_id)
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
queue_(queue),
|
|
|
|
queue_(queue),
|
|
|
|
responder_(&ctx_),
|
|
|
|
responder_(&ctx_),
|
|
|
|
i_(i) {
|
|
|
|
req_id_(req_id) {
|
|
|
|
if (sync_mode_) {
|
|
|
|
if (sync_mode_) {
|
|
|
|
request_.reset(new VariableResponse(scope, dev_ctx_, false));
|
|
|
|
request_.reset(new VariableResponse(scope, dev_ctx_, false));
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -79,7 +79,7 @@ class RequestSend final : public RequestBase {
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(i)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
virtual ~RequestSend() {}
|
|
|
|
virtual ~RequestSend() {}
|
|
|
@ -93,7 +93,7 @@ class RequestSend final : public RequestBase {
|
|
|
|
|
|
|
|
|
|
|
|
status_ = FINISH;
|
|
|
|
status_ = FINISH;
|
|
|
|
responder_.Finish(reply_, ::grpc::Status::OK,
|
|
|
|
responder_.Finish(reply_, ::grpc::Status::OK,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
@ -101,7 +101,7 @@ class RequestSend final : public RequestBase {
|
|
|
|
std::shared_ptr<VariableResponse> request_;
|
|
|
|
std::shared_ptr<VariableResponse> request_;
|
|
|
|
ReceivedQueue* queue_;
|
|
|
|
ReceivedQueue* queue_;
|
|
|
|
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
|
|
|
|
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
|
|
|
|
int i_;
|
|
|
|
int req_id_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class RequestGet final : public RequestBase {
|
|
|
|
class RequestGet final : public RequestBase {
|
|
|
@ -110,16 +110,17 @@ class RequestGet final : public RequestBase {
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
framework::Scope* scope,
|
|
|
|
framework::Scope* scope,
|
|
|
|
const platform::DeviceContext* dev_ctx,
|
|
|
|
const platform::DeviceContext* dev_ctx,
|
|
|
|
framework::BlockingQueue<MessageWithName>* queue, int i)
|
|
|
|
framework::BlockingQueue<MessageWithName>* queue,
|
|
|
|
|
|
|
|
int req_id)
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
responder_(&ctx_),
|
|
|
|
responder_(&ctx_),
|
|
|
|
scope_(scope),
|
|
|
|
scope_(scope),
|
|
|
|
queue_(queue),
|
|
|
|
queue_(queue),
|
|
|
|
i_(i) {
|
|
|
|
req_id_(req_id) {
|
|
|
|
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
|
|
|
|
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
method_id, &ctx_, &request_, &responder_, cq_, cq_,
|
|
|
|
method_id, &ctx_, &request_, &responder_, cq_, cq_,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(i)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
virtual ~RequestGet() {}
|
|
|
|
virtual ~RequestGet() {}
|
|
|
@ -138,7 +139,7 @@ class RequestGet final : public RequestBase {
|
|
|
|
|
|
|
|
|
|
|
|
status_ = FINISH;
|
|
|
|
status_ = FINISH;
|
|
|
|
responder_.Finish(reply_, ::grpc::Status::OK,
|
|
|
|
responder_.Finish(reply_, ::grpc::Status::OK,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
|
|
|
|
|
|
|
|
|
|
|
|
if (var_name == FETCH_BARRIER_MESSAGE) {
|
|
|
|
if (var_name == FETCH_BARRIER_MESSAGE) {
|
|
|
|
sendrecv::VariableMessage msg;
|
|
|
|
sendrecv::VariableMessage msg;
|
|
|
@ -153,7 +154,7 @@ class RequestGet final : public RequestBase {
|
|
|
|
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
|
|
|
|
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
|
|
|
|
framework::Scope* scope_;
|
|
|
|
framework::Scope* scope_;
|
|
|
|
framework::BlockingQueue<MessageWithName>* queue_;
|
|
|
|
framework::BlockingQueue<MessageWithName>* queue_;
|
|
|
|
int i_;
|
|
|
|
int req_id_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class RequestPrefetch final : public RequestBase {
|
|
|
|
class RequestPrefetch final : public RequestBase {
|
|
|
@ -165,14 +166,14 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
framework::Executor* executor,
|
|
|
|
framework::Executor* executor,
|
|
|
|
framework::ProgramDesc* program,
|
|
|
|
framework::ProgramDesc* program,
|
|
|
|
framework::ExecutorPrepareContext* prefetch_ctx,
|
|
|
|
framework::ExecutorPrepareContext* prefetch_ctx,
|
|
|
|
int i)
|
|
|
|
int req_id)
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
responder_(&ctx_),
|
|
|
|
responder_(&ctx_),
|
|
|
|
scope_(scope),
|
|
|
|
scope_(scope),
|
|
|
|
executor_(executor),
|
|
|
|
executor_(executor),
|
|
|
|
program_(program),
|
|
|
|
program_(program),
|
|
|
|
prefetch_ctx_(prefetch_ctx),
|
|
|
|
prefetch_ctx_(prefetch_ctx),
|
|
|
|
i_(i) {
|
|
|
|
req_id_(req_id) {
|
|
|
|
if (sync_mode_) {
|
|
|
|
if (sync_mode_) {
|
|
|
|
request_.reset(new VariableResponse(scope, dev_ctx_, false));
|
|
|
|
request_.reset(new VariableResponse(scope, dev_ctx_, false));
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -202,7 +203,7 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
|
|
|
|
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
|
|
|
|
|
|
|
|
|
|
|
|
responder_.Finish(reply, ::grpc::Status::OK,
|
|
|
|
responder_.Finish(reply, ::grpc::Status::OK,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(i_)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
|
|
|
|
status_ = FINISH;
|
|
|
|
status_ = FINISH;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -213,7 +214,7 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
framework::Executor* executor_;
|
|
|
|
framework::Executor* executor_;
|
|
|
|
framework::ProgramDesc* program_;
|
|
|
|
framework::ProgramDesc* program_;
|
|
|
|
framework::ExecutorPrepareContext* prefetch_ctx_;
|
|
|
|
framework::ExecutorPrepareContext* prefetch_ctx_;
|
|
|
|
int i_;
|
|
|
|
int req_id_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::WaitClientGet(int count) {
|
|
|
|
void AsyncGRPCServer::WaitClientGet(int count) {
|
|
|
@ -291,21 +292,6 @@ void AsyncGRPCServer::RunSyncUpdate() {
|
|
|
|
for (int i = 0; i < kNumHandleGetThreads; ++i) {
|
|
|
|
for (int i = 0; i < kNumHandleGetThreads; ++i) {
|
|
|
|
t_gets_[i]->join();
|
|
|
|
t_gets_[i]->join();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
{
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> l(cq_mutex_);
|
|
|
|
|
|
|
|
for (int i = 0; i < kSendReqsBufSize; ++i) {
|
|
|
|
|
|
|
|
if (send_reqs_[i]) {
|
|
|
|
|
|
|
|
delete send_reqs_[i];
|
|
|
|
|
|
|
|
send_reqs_[i] = nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int i = 0; i < kGetReqsBufSize; ++i) {
|
|
|
|
|
|
|
|
if (get_reqs_[i]) {
|
|
|
|
|
|
|
|
delete get_reqs_[i];
|
|
|
|
|
|
|
|
get_reqs_[i] = nullptr;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
t_prefetch_->join();
|
|
|
|
t_prefetch_->join();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -335,19 +321,19 @@ void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
|
|
|
|
VLOG(4) << "Create RequestSend status:" << send->Status();
|
|
|
|
VLOG(4) << "Create RequestSend status:" << send->Status();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::TryToRegisterNewGetOne(int i) {
|
|
|
|
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) {
|
|
|
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
|
|
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
|
|
|
if (is_shut_down_) {
|
|
|
|
if (is_shut_down_) {
|
|
|
|
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
|
|
|
|
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
|
|
|
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
|
|
|
|
dev_ctx_, &var_get_queue_, i);
|
|
|
|
dev_ctx_, &var_get_queue_, req_id);
|
|
|
|
get_reqs_[i] = static_cast<RequestBase*>(get);
|
|
|
|
get_reqs_[req_id] = static_cast<RequestBase*>(get);
|
|
|
|
VLOG(4) << "Create RequestGet status:" << get->Status();
|
|
|
|
VLOG(4) << "Create RequestGet status:" << get->Status();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) {
|
|
|
|
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) {
|
|
|
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
|
|
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
|
|
|
if (is_shut_down_) {
|
|
|
|
if (is_shut_down_) {
|
|
|
|
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
|
|
|
|
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
|
|
|
@ -355,7 +341,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
RequestPrefetch* prefetch = new RequestPrefetch(
|
|
|
|
RequestPrefetch* prefetch = new RequestPrefetch(
|
|
|
|
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
|
|
|
|
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
|
|
|
|
program_, prefetch_ctx_.get(), i);
|
|
|
|
program_, prefetch_ctx_.get(), req_id);
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
|
|
|
|
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -374,7 +360,7 @@ void AsyncGRPCServer::HandleRequest(
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
|
|
|
|
int i = static_cast<int>(reinterpret_cast<intptr_t>(tag));
|
|
|
|
int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
|
|
|
|
|
|
|
|
|
|
|
|
if (sync_mode_) {
|
|
|
|
if (sync_mode_) {
|
|
|
|
// FIXME(typhoonzero): de-couple the barriers with recv_op
|
|
|
|
// FIXME(typhoonzero): de-couple the barriers with recv_op
|
|
|
@ -387,9 +373,9 @@ void AsyncGRPCServer::HandleRequest(
|
|
|
|
{
|
|
|
|
{
|
|
|
|
std::lock_guard<std::mutex> l(cq_mutex_);
|
|
|
|
std::lock_guard<std::mutex> l(cq_mutex_);
|
|
|
|
if (cq_name == "cq_get") {
|
|
|
|
if (cq_name == "cq_get") {
|
|
|
|
base = get_reqs_[i];
|
|
|
|
base = get_reqs_[req_id];
|
|
|
|
} else if (cq_name == "cq_send") {
|
|
|
|
} else if (cq_name == "cq_send") {
|
|
|
|
base = send_reqs_[i];
|
|
|
|
base = send_reqs_[req_id];
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
CHECK(false);
|
|
|
|
CHECK(false);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -401,7 +387,7 @@ void AsyncGRPCServer::HandleRequest(
|
|
|
|
if (!ok) {
|
|
|
|
if (!ok) {
|
|
|
|
LOG(WARNING) << cq_name << " recv no regular event:argument name["
|
|
|
|
LOG(WARNING) << cq_name << " recv no regular event:argument name["
|
|
|
|
<< base->GetReqName() << "]";
|
|
|
|
<< base->GetReqName() << "]";
|
|
|
|
TryToRegisterNewOne(i);
|
|
|
|
TryToRegisterNewOne(req_id);
|
|
|
|
delete base;
|
|
|
|
delete base;
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -413,7 +399,7 @@ void AsyncGRPCServer::HandleRequest(
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
case FINISH: {
|
|
|
|
case FINISH: {
|
|
|
|
TryToRegisterNewOne(i);
|
|
|
|
TryToRegisterNewOne(req_id);
|
|
|
|
VLOG(4) << cq_name << " FINISH status:" << base->Status();
|
|
|
|
VLOG(4) << cq_name << " FINISH status:" << base->Status();
|
|
|
|
delete base;
|
|
|
|
delete base;
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|