|
|
|
@ -185,6 +185,37 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
|
framework::Scope* local_scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RequestCheckpointNotify final : public RequestBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
|
|
|
|
|
::grpc::ServerCompletionQueue* cq,
|
|
|
|
|
RequestHandler* request_handler, int req_id)
|
|
|
|
|
: RequestBase(service, cq, request_handler, req_id),
|
|
|
|
|
responder_(&ctx_),
|
|
|
|
|
local_scope_(nullptr) {
|
|
|
|
|
request_.reset(new VariableResponse(request_handler->scope(),
|
|
|
|
|
request_handler->dev_ctx(), true));
|
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
|
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual ~RequestCheckpointNotify() {}
|
|
|
|
|
|
|
|
|
|
std::string GetReqName() override { return request_->Varname(); }
|
|
|
|
|
|
|
|
|
|
void Process() override {
|
|
|
|
|
auto scope = request_->GetMutableLocalScope();
|
|
|
|
|
std::string nullptr_str = nullptr;
|
|
|
|
|
framework::Variable* invar = nullptr;
|
|
|
|
|
framework::Variable* outvar = nullptr;
|
|
|
|
|
|
|
|
|
|
request_handler_->Handle(nullptr_str, scope, invar, &outvar, nullptr_str);
|
|
|
|
|
Finish(reply_, &responder_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::WaitServerReady() {
|
|
|
|
|
VLOG(3) << "AsyncGRPCServer is wait server ready";
|
|
|
|
|
std::unique_lock<std::mutex> lock(this->mutex_ready_);
|
|
|
|
@ -288,6 +319,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
|
|
|
|
|
b = new RequestGet(&service_, cq.get(), handler, req_id);
|
|
|
|
|
} else if (rpc_name == kRequestPrefetch) {
|
|
|
|
|
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
|
|
|
|
|
} else if (rpc_name == kRequestCheckpoint) {
|
|
|
|
|
b = new RequestCheckpoin
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(false, "not supported rpc");
|
|
|
|
|
}
|
|
|
|
|