|
|
|
@ -36,7 +36,10 @@ class RequestBase {
|
|
|
|
|
|
|
|
|
|
CallStatus Status() { return status_; }
|
|
|
|
|
void SetStatus(CallStatus status) { status_ = status; }
|
|
|
|
|
virtual std::string GetReqName() { assert(false); }
|
|
|
|
|
virtual std::string GetReqName() {
|
|
|
|
|
assert(false);
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
grpc::ServerContext ctx_;
|
|
|
|
@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
|
|
|
|
|
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
|
|
|
|
|
const platform::DeviceContext* dev_ctx)
|
|
|
|
|
const platform::DeviceContext* dev_ctx,
|
|
|
|
|
SimpleBlockQueue<char>* queue)
|
|
|
|
|
: RequestBase(service, cq),
|
|
|
|
|
responder_(&ctx_),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
dev_ctx_(dev_ctx) {
|
|
|
|
|
dev_ctx_(dev_ctx),
|
|
|
|
|
queue_(queue) {
|
|
|
|
|
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
|
|
|
|
|
// TODO(gongwb): check var's info.
|
|
|
|
|
responder_.Finish(reply_, grpc::Status::OK, this);
|
|
|
|
|
status_ = FINISH;
|
|
|
|
|
queue_->Push('c');
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
|
|
|
|
|
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
|
|
|
|
|
framework::Scope* scope_;
|
|
|
|
|
const platform::DeviceContext* dev_ctx_;
|
|
|
|
|
SimpleBlockQueue<char>* queue_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::WaitClientGet(int count) {
|
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
|
var_get_queue_.Pop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::RunSyncUpdate() {
|
|
|
|
|
grpc::ServerBuilder builder;
|
|
|
|
|
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
|
|
|
|
@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
|
|
|
|
|
if (is_shut_down_) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
|
|
|
|
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
|
|
|
|
|
&var_get_queue_);
|
|
|
|
|
VLOG(4) << "create Requestget status:" << get->Status();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(tag);
|
|
|
|
|
if (wait && !done_) {
|
|
|
|
|
Wait();
|
|
|
|
|
}
|
|
|
|
|
if (cq_name == "cq_get") WaitCond(2);
|
|
|
|
|
if (cq_name == "cq_send") WaitCond(0);
|
|
|
|
|
|
|
|
|
|
RequestBase* base = (RequestBase*)tag;
|
|
|
|
|
// reference:
|
|
|
|
@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::Wait() {
|
|
|
|
|
std::unique_lock<std::mutex> lock(this->mutex_);
|
|
|
|
|
condition_.wait(lock, [=] { return this->done_ == true; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::Reset() {
|
|
|
|
|
std::lock_guard<std::mutex> lock(this->mutex_);
|
|
|
|
|
done_ = false;
|
|
|
|
|
void AsyncGRPCServer::WaitCond(int cond) {
|
|
|
|
|
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
|
|
|
|
|
barrier_condition_.wait(lock,
|
|
|
|
|
[=] { return this->barrier_cond_step_ == cond; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::Done() {
|
|
|
|
|
void AsyncGRPCServer::SetCond(int cond) {
|
|
|
|
|
{
|
|
|
|
|
std::lock_guard<std::mutex> lock(this->mutex_);
|
|
|
|
|
done_ = true;
|
|
|
|
|
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
|
|
|
|
|
barrier_cond_step_ = cond;
|
|
|
|
|
}
|
|
|
|
|
condition_.notify_all();
|
|
|
|
|
barrier_condition_.notify_all();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace detail
|
|
|
|
|