|
|
@ -82,7 +82,9 @@ class RequestSend final : public RequestBase {
|
|
|
|
virtual std::string GetReqName() { return request_->Varname(); }
|
|
|
|
virtual std::string GetReqName() { return request_->Varname(); }
|
|
|
|
|
|
|
|
|
|
|
|
virtual void Process() {
|
|
|
|
virtual void Process() {
|
|
|
|
queue_->Push(std::make_pair(request_->Varname(), request_));
|
|
|
|
std::string var_name = GetReqName();
|
|
|
|
|
|
|
|
VLOG(3) << "RequestSend " << var_name;
|
|
|
|
|
|
|
|
queue_->Push(std::make_pair(var_name, request_));
|
|
|
|
|
|
|
|
|
|
|
|
sendrecv::VoidMessage reply;
|
|
|
|
sendrecv::VoidMessage reply;
|
|
|
|
responder_.Finish(reply, ::grpc::Status::OK, this);
|
|
|
|
responder_.Finish(reply, ::grpc::Status::OK, this);
|
|
|
@ -106,7 +108,7 @@ class RequestGet final : public RequestBase {
|
|
|
|
responder_(&ctx_),
|
|
|
|
responder_(&ctx_),
|
|
|
|
scope_(scope),
|
|
|
|
scope_(scope),
|
|
|
|
queue_(queue) {
|
|
|
|
queue_(queue) {
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
|
|
|
|
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
|
|
|
|
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
|
|
|
|
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
|
|
|
|
cq_, this);
|
|
|
|
cq_, this);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -118,6 +120,7 @@ class RequestGet final : public RequestBase {
|
|
|
|
virtual void Process() {
|
|
|
|
virtual void Process() {
|
|
|
|
// proc request.
|
|
|
|
// proc request.
|
|
|
|
std::string var_name = request_.varname();
|
|
|
|
std::string var_name = request_.varname();
|
|
|
|
|
|
|
|
VLOG(3) << "RequestGet " << var_name;
|
|
|
|
auto* var = scope_->FindVar(var_name);
|
|
|
|
auto* var = scope_->FindVar(var_name);
|
|
|
|
|
|
|
|
|
|
|
|
::grpc::ByteBuffer reply;
|
|
|
|
::grpc::ByteBuffer reply;
|
|
|
@ -176,7 +179,7 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
::grpc::ByteBuffer reply;
|
|
|
|
::grpc::ByteBuffer reply;
|
|
|
|
|
|
|
|
|
|
|
|
std::string var_name = request_->OutVarname();
|
|
|
|
std::string var_name = request_->OutVarname();
|
|
|
|
VLOG(3) << "prefetch var " << var_name;
|
|
|
|
VLOG(3) << "RequestPrefetch " << var_name;
|
|
|
|
auto var_desc = program_->Block(0).FindVar(var_name);
|
|
|
|
auto var_desc = program_->Block(0).FindVar(var_name);
|
|
|
|
framework::Scope* local_scope = &scope_->NewScope();
|
|
|
|
framework::Scope* local_scope = &scope_->NewScope();
|
|
|
|
auto* var = local_scope->FindVar(var_name);
|
|
|
|
auto* var = local_scope->FindVar(var_name);
|
|
|
@ -307,18 +310,21 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
|
|
|
|
bool ok = false;
|
|
|
|
bool ok = false;
|
|
|
|
|
|
|
|
|
|
|
|
while (true) {
|
|
|
|
while (true) {
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " while in";
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " wait Next";
|
|
|
|
if (!cq->Next(&tag, &ok)) {
|
|
|
|
if (!cq->Next(&tag, &ok)) {
|
|
|
|
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
|
|
|
|
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(tag);
|
|
|
|
PADDLE_ENFORCE(tag);
|
|
|
|
|
|
|
|
|
|
|
|
if (sync_mode_) {
|
|
|
|
if (sync_mode_) {
|
|
|
|
// FIXME(typhoonzero): de-couple the barriers with recv_op
|
|
|
|
// FIXME(typhoonzero): de-couple the barriers with recv_op
|
|
|
|
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " before WaitCond";
|
|
|
|
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
|
|
|
|
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
|
|
|
|
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
|
|
|
|
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
|
|
|
|
|
|
|
|
VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
|
|
|
|
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
|
|
|
@ -353,8 +359,10 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
|
|
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::WaitCond(int cond) {
|
|
|
|
void AsyncGRPCServer::WaitCond(int cond) {
|
|
|
|
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
|
|
|
|
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
|
|
|
|
|
|
|
|
VLOG(3) << "WaitCond " << cond << " in";
|
|
|
|
barrier_condition_.wait(lock,
|
|
|
|
barrier_condition_.wait(lock,
|
|
|
|
[=] { return this->barrier_cond_step_ == cond; });
|
|
|
|
[=] { return this->barrier_cond_step_ == cond; });
|
|
|
|
|
|
|
|
VLOG(3) << "WaitCond " << cond << " out";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::SetCond(int cond) {
|
|
|
|
void AsyncGRPCServer::SetCond(int cond) {
|
|
|
|