fix grpc_server_test

wangkuiyi-patch-1
qiaolongfei 7 years ago
parent 4e36c0ecab
commit 8fb78f6c07

@ -158,13 +158,14 @@ class RequestPrefetch final : public RequestBase {
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
VLOG(3) << "in_var_name: " << in_var_name
<< "out_var_name: " << out_var_name
<< " RequestPrefetch: " << out_var_name;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = nullptr;
framework::Variable* outvar = scope->FindVar(out_var_name);
request_handler_->Handle(in_var_name, scope, invar, &outvar);
request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
@ -284,7 +285,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not surpported rpc");
PADDLE_ENFORCE(false, "not supported rpc");
}
reqs[req_id] = b;

@ -96,8 +96,8 @@ class RequestHandler {
// *request_handler_->dev_ctx(), &reply_);
// }
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var,
framework::Variable** outvar) = 0;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") = 0;
protected:
const bool sync_mode_;

@ -33,7 +33,8 @@ namespace detail {
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Async
@ -82,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() {
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) {
@ -105,11 +107,11 @@ bool RequestGetHandler::Handle(const std::string& varname,
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
framework::Variable** outvar,
const std::string& out_var_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname);
*outvar = scope->FindVar(varname);
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);

@ -40,7 +40,8 @@ class RequestSendHandler final : public RequestHandler {
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
void ResetSparseVarRecorder();
private:
@ -53,7 +54,8 @@ class RequestGetHandler final : public RequestHandler {
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
};
class RequestPrefetchHandler final : public RequestHandler {
@ -61,7 +63,8 @@ class RequestPrefetchHandler final : public RequestHandler {
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
};
} // namespace detail

Loading…
Cancel
Save