|
|
|
@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
|
|
|
|
|
class RequestBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit RequestBase(GrpcService::AsyncService* service,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
|
const platform::DeviceContext* dev_ctx)
|
|
|
|
|
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
|
|
|
|
|
: service_(service),
|
|
|
|
|
cq_(cq),
|
|
|
|
|
sync_mode_(sync_mode),
|
|
|
|
|
status_(PROCESS),
|
|
|
|
|
dev_ctx_(dev_ctx) {
|
|
|
|
|
PADDLE_ENFORCE(cq_);
|
|
|
|
|
}
|
|
|
|
|
virtual ~RequestBase() {}
|
|
|
|
@ -49,6 +53,7 @@ class RequestBase {
|
|
|
|
|
::grpc::ServerContext ctx_;
|
|
|
|
|
GrpcService::AsyncService* service_;
|
|
|
|
|
::grpc::ServerCompletionQueue* cq_;
|
|
|
|
|
const bool sync_mode_;
|
|
|
|
|
CallStatus status_;
|
|
|
|
|
const platform::DeviceContext* dev_ctx_;
|
|
|
|
|
};
|
|
|
|
@ -56,11 +61,17 @@ class RequestBase {
|
|
|
|
|
class RequestSend final : public RequestBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit RequestSend(GrpcService::AsyncService* service,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
|
framework::Scope* scope, ReceivedQueue* queue,
|
|
|
|
|
const platform::DeviceContext* dev_ctx)
|
|
|
|
|
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
|
|
|
|
|
request_.reset(new VariableResponse(false, scope, dev_ctx_));
|
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
|
queue_(queue),
|
|
|
|
|
responder_(&ctx_) {
|
|
|
|
|
if (sync_mode_) {
|
|
|
|
|
request_.reset(new VariableResponse(false, scope, dev_ctx_));
|
|
|
|
|
} else {
|
|
|
|
|
request_.reset(new VariableResponse(true, scope, dev_ctx_));
|
|
|
|
|
}
|
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
|
|
|
|
|
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
|
|
|
|
|
cq_, cq_, this);
|
|
|
|
@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
|
|
|
|
|
class RequestGet final : public RequestBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit RequestGet(GrpcService::AsyncService* service,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
|
framework::Scope* scope,
|
|
|
|
|
const platform::DeviceContext* dev_ctx,
|
|
|
|
|
SimpleBlockQueue<MessageWithName>* queue)
|
|
|
|
|
: RequestBase(service, cq, dev_ctx),
|
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
|
responder_(&ctx_),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
queue_(queue) {
|
|
|
|
@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
|
|
|
|
|
class RequestPrefetch final : public RequestBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit RequestPrefetch(GrpcService::AsyncService* service,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq, bool sync_mode,
|
|
|
|
|
framework::Scope* scope,
|
|
|
|
|
const platform::DeviceContext* dev_ctx,
|
|
|
|
|
framework::Executor* executor,
|
|
|
|
|
framework::ProgramDesc* program,
|
|
|
|
|
framework::ExecutorPrepareContext* prefetch_ctx)
|
|
|
|
|
: RequestBase(service, cq, dev_ctx),
|
|
|
|
|
: RequestBase(service, cq, sync_mode, dev_ctx),
|
|
|
|
|
responder_(&ctx_),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
executor_(executor),
|
|
|
|
|
program_(program),
|
|
|
|
|
prefetch_ctx_(prefetch_ctx) {
|
|
|
|
|
request_.reset(new VariableResponse(false, scope, dev_ctx_));
|
|
|
|
|
if (sync_mode_) {
|
|
|
|
|
request_.reset(new VariableResponse(false, scope, dev_ctx_));
|
|
|
|
|
} else {
|
|
|
|
|
request_.reset(new VariableResponse(true, scope, dev_ctx_));
|
|
|
|
|
}
|
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
|
|
|
|
|
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
|
|
|
|
|
cq_, cq_, this);
|
|
|
|
@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
|
framework::Executor* executor_;
|
|
|
|
|
framework::ProgramDesc* program_;
|
|
|
|
|
framework::ExecutorPrepareContext* prefetch_ctx_;
|
|
|
|
|
int blkid_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::WaitClientGet(int count) {
|
|
|
|
@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
|
|
|
|
|
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
|
|
|
|
|
&var_recv_queue_, dev_ctx_);
|
|
|
|
|
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
|
|
|
|
|
scope_, &var_recv_queue_, dev_ctx_);
|
|
|
|
|
VLOG(4) << "Create RequestSend status:" << send->Status();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
|
|
|
|
|
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
|
|
|
|
|
&var_get_queue_);
|
|
|
|
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
|
|
|
|
|
dev_ctx_, &var_get_queue_);
|
|
|
|
|
VLOG(4) << "Create RequestGet status:" << get->Status();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
RequestPrefetch* prefetch =
|
|
|
|
|
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
|
|
|
|
|
executor_, program_, prefetch_ctx_);
|
|
|
|
|
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
|
|
|
|
|
dev_ctx_, executor_, program_, prefetch_ctx_);
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
|
|
|
|
|
}
|
|
|
|
|