parent
020630b7a3
commit
da3087ada1
@ -1 +1 @@
|
||||
grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
|
@ -0,0 +1,147 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "grpc_client.h"
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
bool RPCClient::AsyncSendVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out) {
|
||||
sendrecv::VariableMessage req;
|
||||
auto* var = scope.FindVar(var_name);
|
||||
SerializeToMessage(var_name, var, ctx, &req);
|
||||
|
||||
// varhandle
|
||||
VarHandle var_h;
|
||||
var_h.ep = ep;
|
||||
var_h.scope = &scope;
|
||||
var_h.name = var_name;
|
||||
var_h.ctx = &ctx;
|
||||
|
||||
// stub context
|
||||
auto ch = GetChannel(ep);
|
||||
SendProcessor* s = new SendProcessor(ch);
|
||||
s->Prepare(var_h, time_out);
|
||||
s->response_call_back_ = NULL;
|
||||
|
||||
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
|
||||
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||
|
||||
req_count_++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ProcGetResponse(const VarHandle& var_h,
|
||||
const sendrecv::VariableMessage& ret_msg) {
|
||||
auto* outvar = var_h.scope->FindVar(var_h.name);
|
||||
|
||||
std::istringstream iss(ret_msg.serialized());
|
||||
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
|
||||
}
|
||||
|
||||
bool RPCClient::AsyncGetVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out) {
|
||||
sendrecv::VariableMessage req;
|
||||
req.set_varname(var_name);
|
||||
|
||||
auto* var = scope.FindVar(var_name);
|
||||
SerializeToMessage(var_name, var, ctx, &req);
|
||||
|
||||
// varhandle
|
||||
VarHandle var_h;
|
||||
var_h.ep = ep;
|
||||
var_h.scope = &scope;
|
||||
var_h.name = var_name;
|
||||
var_h.ctx = &ctx;
|
||||
|
||||
// stub context
|
||||
auto ch = GetChannel(ep);
|
||||
GetProcessor* s = new GetProcessor(ch);
|
||||
s->Prepare(var_h, time_out);
|
||||
s->response_call_back_ = ProcGetResponse;
|
||||
|
||||
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
|
||||
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||
|
||||
req_count_++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RPCClient::wait() {
|
||||
bool ok = true;
|
||||
|
||||
while (true) {
|
||||
if (req_count_ <= 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (!Proceed()) {
|
||||
LOG(ERROR) << "Get meets CompletionQueue error";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return ok;
|
||||
}
|
||||
|
||||
bool RPCClient::Proceed() {
|
||||
void* tag = NULL;
|
||||
bool ok = false;
|
||||
|
||||
// request counts.
|
||||
if (!cq_.Next(&tag, &ok)) {
|
||||
return false;
|
||||
}
|
||||
req_count_--;
|
||||
|
||||
GPR_ASSERT(ok);
|
||||
PADDLE_ENFORCE(tag);
|
||||
|
||||
// TODO(gongwb): add more retries.
|
||||
ClientBase* c = static_cast<ClientBase*>(tag);
|
||||
if (!c->status_.ok()) {
|
||||
delete c;
|
||||
return true;
|
||||
}
|
||||
|
||||
c->Process();
|
||||
delete c;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
|
||||
auto it = channels_.find(ep);
|
||||
if (it != channels_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
auto ch = std::shared_ptr<grpc::Channel>(
|
||||
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
|
||||
|
||||
channels_[ep] = ch;
|
||||
return ch;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,147 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
#include <grpc/support/log.h>
|
||||
#include <time.h>
|
||||
#include <chrono>
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/framework/data_type.h"
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||
#include "paddle/operators/detail/simple_block_queue.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
struct VarHandle {
|
||||
std::string ep;
|
||||
const platform::DeviceContext* ctx;
|
||||
const framework::Scope* scope;
|
||||
std::string name;
|
||||
|
||||
std::string String() const {
|
||||
std::ostringstream s;
|
||||
s << "name:[" << name << "] ep:[" << ep << "]";
|
||||
return s.str();
|
||||
}
|
||||
};
|
||||
|
||||
void ProcGetResponse(const VarHandle& var_h,
|
||||
const sendrecv::VariableMessage& msg);
|
||||
|
||||
class ClientBase {
|
||||
public:
|
||||
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
|
||||
stub_ = sendrecv::SendRecvService::NewStub(ch);
|
||||
context_ = NULL;
|
||||
}
|
||||
|
||||
virtual ~ClientBase() {}
|
||||
|
||||
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
|
||||
context_.reset(new grpc::ClientContext());
|
||||
var_h_ = var_info;
|
||||
|
||||
std::chrono::system_clock::time_point deadline =
|
||||
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
|
||||
|
||||
context_->set_deadline(deadline);
|
||||
}
|
||||
|
||||
virtual void Process() = 0;
|
||||
|
||||
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
|
||||
std::unique_ptr<grpc::ClientContext> context_;
|
||||
grpc::Status status_;
|
||||
VarHandle var_h_;
|
||||
};
|
||||
|
||||
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
|
||||
RequestSendCallBack;
|
||||
|
||||
class SendProcessor : public ClientBase {
|
||||
public:
|
||||
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||
|
||||
virtual ~SendProcessor() {}
|
||||
|
||||
virtual void Process() {
|
||||
if (response_call_back_) {
|
||||
response_call_back_(var_h_, reply_);
|
||||
}
|
||||
}
|
||||
|
||||
sendrecv::VoidMessage reply_;
|
||||
RequestSendCallBack response_call_back_ = NULL;
|
||||
};
|
||||
|
||||
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
|
||||
RequestGetCallBack;
|
||||
|
||||
class GetProcessor : public ClientBase {
|
||||
public:
|
||||
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||
|
||||
virtual ~GetProcessor() {}
|
||||
|
||||
virtual void Process() {
|
||||
if (response_call_back_) {
|
||||
response_call_back_(var_h_, reply_);
|
||||
}
|
||||
}
|
||||
|
||||
sendrecv::VariableMessage reply_;
|
||||
RequestGetCallBack response_call_back_ = ProcGetResponse;
|
||||
};
|
||||
|
||||
class RPCClient {
|
||||
public:
|
||||
bool AsyncSendVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out = 600 * 1000);
|
||||
|
||||
bool AsyncGetVariable(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name,
|
||||
int64_t time_out = 600 * 1000);
|
||||
bool wait();
|
||||
|
||||
private:
|
||||
bool Proceed();
|
||||
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
|
||||
|
||||
private:
|
||||
grpc::CompletionQueue cq_;
|
||||
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
|
||||
int64_t req_count_ = 0;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,237 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/detail/grpc_server.h"
|
||||
|
||||
using grpc::ServerAsyncResponseWriter;
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
enum CallStatus { PROCESS = 0, FINISH };
|
||||
|
||||
// reference:
|
||||
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
|
||||
class RequestBase {
|
||||
public:
|
||||
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
|
||||
grpc::ServerCompletionQueue* cq)
|
||||
: service_(service), cq_(cq), status_(PROCESS) {}
|
||||
virtual ~RequestBase() {}
|
||||
virtual void Process() { assert(false); }
|
||||
|
||||
CallStatus Status() { return status_; }
|
||||
void SetStatus(CallStatus status) { status_ = status; }
|
||||
|
||||
protected:
|
||||
grpc::ServerContext ctx_;
|
||||
sendrecv::SendRecvService::AsyncService* service_;
|
||||
grpc::ServerCompletionQueue* cq_;
|
||||
CallStatus status_;
|
||||
};
|
||||
|
||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||
|
||||
class RequestSend final : public RequestBase {
|
||||
public:
|
||||
explicit RequestSend(sendrecv::SendRecvService::AsyncService* service,
|
||||
grpc::ServerCompletionQueue* cq,
|
||||
SimpleBlockQueue<MessageWithName>* queue)
|
||||
: RequestBase(service, cq), queue_(queue), responder_(&ctx_) {
|
||||
service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_,
|
||||
this);
|
||||
}
|
||||
|
||||
virtual ~RequestSend() {}
|
||||
|
||||
virtual void Process() {
|
||||
MessageWithName msg_with_name =
|
||||
std::make_pair(request_.varname(), std::move(request_));
|
||||
queue_->Push(std::move(msg_with_name));
|
||||
// TODO(gongwb): check var's info.
|
||||
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||
}
|
||||
|
||||
protected:
|
||||
sendrecv::VariableMessage request_;
|
||||
sendrecv::VoidMessage reply_;
|
||||
SimpleBlockQueue<MessageWithName>* queue_;
|
||||
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
|
||||
};
|
||||
|
||||
class RequestGet final : public RequestBase {
|
||||
public:
|
||||
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
|
||||
grpc::ServerCompletionQueue* cq, framework::Scope* scope)
|
||||
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) {
|
||||
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
|
||||
}
|
||||
|
||||
virtual ~RequestGet() {}
|
||||
|
||||
virtual void Process() {
|
||||
// proc request.
|
||||
std::string var_name = request_.varname();
|
||||
auto* var = scope_->FindVar(var_name);
|
||||
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
|
||||
// TODO(gongwb): check var's info.
|
||||
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||
}
|
||||
|
||||
protected:
|
||||
sendrecv::VariableMessage request_;
|
||||
sendrecv::VariableMessage reply_;
|
||||
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
|
||||
framework::Scope* scope_;
|
||||
};
|
||||
|
||||
void AsyncGRPCServer::RunSyncUpdate() {
|
||||
grpc::ServerBuilder builder;
|
||||
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service_);
|
||||
|
||||
cq_send_ = builder.AddCompletionQueue();
|
||||
cq_get_ = builder.AddCompletionQueue();
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(INFO) << "Server listening on " << address_ << std::endl;
|
||||
|
||||
std::function<void()> send_register =
|
||||
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
|
||||
std::function<void()> get_register =
|
||||
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
|
||||
|
||||
t_send_.reset(
|
||||
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
|
||||
cq_send_.get(), "cq_send", send_register)));
|
||||
|
||||
t_get_.reset(
|
||||
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
|
||||
cq_get_.get(), "cq_get", get_register)));
|
||||
|
||||
// wait server
|
||||
server_->Wait();
|
||||
t_send_->join();
|
||||
t_get_->join();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::ShutdownQueue() {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
cq_send_->Shutdown();
|
||||
cq_get_->Shutdown();
|
||||
is_shut_down_ = true;
|
||||
}
|
||||
|
||||
// This URL explains why shutdown is complicate:
|
||||
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
|
||||
void AsyncGRPCServer::ShutDown() {
|
||||
server_->Shutdown();
|
||||
ShutdownQueue();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::TryToRegisterNewSendOne() {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
if (is_shut_down_) {
|
||||
return;
|
||||
}
|
||||
RequestSend* send =
|
||||
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
|
||||
VLOG(4) << "create RequestSend status:" << send->Status();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::TryToRegisterNewGetOne() {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
if (is_shut_down_) {
|
||||
return;
|
||||
}
|
||||
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_);
|
||||
VLOG(4) << "create Requestget status:" << get->Status();
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
|
||||
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||
if (is_shut_down_) {
|
||||
delete last;
|
||||
last = NULL;
|
||||
return;
|
||||
}
|
||||
|
||||
last->SetStatus(FINISH);
|
||||
return;
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
|
||||
std::string cq_name,
|
||||
std::function<void()> TryToRegisterNewOne) {
|
||||
TryToRegisterNewOne();
|
||||
|
||||
void* tag = NULL;
|
||||
bool ok = false;
|
||||
while (true) {
|
||||
if (!cq->Next(&tag, &ok)) {
|
||||
LOG(INFO) << cq_name << " get CompletionQueue shutdown!";
|
||||
break;
|
||||
}
|
||||
|
||||
if (wait && !done_) {
|
||||
Wait();
|
||||
}
|
||||
|
||||
RequestBase* base = (RequestBase*)tag;
|
||||
if (!ok) {
|
||||
VLOG(4) << cq_name << " recv no regular event";
|
||||
TryToRegisterNewOne();
|
||||
delete base;
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (base->Status()) {
|
||||
case PROCESS: {
|
||||
VLOG(4) << cq_name << " status:" << base->Status();
|
||||
TryToRegisterNewOne();
|
||||
base->Process();
|
||||
SetFinishOrDelete(base);
|
||||
break;
|
||||
}
|
||||
case FINISH: {
|
||||
VLOG(4) << cq_name << " status:" << base->Status();
|
||||
delete base;
|
||||
break;
|
||||
}
|
||||
default: { assert(false); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::Wait() {
|
||||
std::unique_lock<std::mutex> lock(this->mutex_);
|
||||
condition_.wait(lock, [=] { return this->done_ == true; });
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::Reset() {
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = false;
|
||||
}
|
||||
|
||||
void AsyncGRPCServer::Done() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = true;
|
||||
}
|
||||
condition_.notify_all();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,91 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/framework/var_type.h"
|
||||
#include "paddle/operators/detail/simple_block_queue.h"
|
||||
|
||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||
#include "paddle/operators/detail/send_recv.pb.h"
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
#include <grpc/support/log.h>
|
||||
#include <thread>
|
||||
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||
class RequestBase;
|
||||
|
||||
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
|
||||
public:
|
||||
explicit AsyncGRPCServer(std::string address) { address_ = address; }
|
||||
|
||||
void RunSyncUpdate();
|
||||
|
||||
void Reset();
|
||||
|
||||
void Done();
|
||||
|
||||
void SetScope(framework::Scope *scope) { scope_ = scope; }
|
||||
|
||||
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
||||
|
||||
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
||||
|
||||
void ShutDown();
|
||||
|
||||
protected:
|
||||
void Wait();
|
||||
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
|
||||
std::string cq_name,
|
||||
std::function<void()> TryToRegisterNewOne);
|
||||
void TryToRegisterNewSendOne();
|
||||
void TryToRegisterNewGetOne();
|
||||
void SetFinishOrDelete(RequestBase *&last);
|
||||
void ShutdownQueue();
|
||||
|
||||
private:
|
||||
std::mutex cq_mutex_;
|
||||
volatile bool is_shut_down_ = false;
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_;
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_;
|
||||
|
||||
sendrecv::SendRecvService::AsyncService service_;
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
|
||||
std::string address_;
|
||||
framework::Scope *scope_;
|
||||
// received variable from RPC, operators fetch variable from this queue.
|
||||
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
||||
|
||||
// condition of the sub program
|
||||
std::mutex mutex_;
|
||||
volatile mutable bool done_;
|
||||
std::condition_variable condition_;
|
||||
|
||||
std::unique_ptr<std::thread> t_send_;
|
||||
std::unique_ptr<std::thread> t_get_;
|
||||
};
|
||||
|
||||
}; // namespace detail
|
||||
}; // namespace operators
|
||||
}; // namespace paddle
|
@ -1,65 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "send_recv_impl.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
Status SendRecvServerImpl::SendVariable(ServerContext *context,
|
||||
const VariableMessage *in_var,
|
||||
VoidMessage *out_var) {
|
||||
MessageWithName msg_with_name =
|
||||
std::make_pair(in_var->varname(), std::move(*in_var));
|
||||
var_recv_queue_.Push(std::move(msg_with_name));
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SendRecvServerImpl::GetVariable(ServerContext *context,
|
||||
const VariableMessage *in_var,
|
||||
VariableMessage *out_var) {
|
||||
std::string get_var_name = in_var->varname();
|
||||
auto *var = scope_->FindVar(get_var_name);
|
||||
|
||||
SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var);
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SendRecvServerImpl::Wait(ServerContext *context,
|
||||
const VoidMessage *in_var,
|
||||
VoidMessage *out_var) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->mutex_);
|
||||
condition_.wait(lock, [=] { return this->done_ == true; });
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void SendRecvServerImpl::Reset() {
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = false;
|
||||
}
|
||||
|
||||
void SendRecvServerImpl::Done() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
done_ = true;
|
||||
}
|
||||
condition_.notify_all();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -1,67 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "send_recv_impl.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
bool RPCClient::SendVariable(const framework::Scope& scope,
|
||||
const std::string& inname) {
|
||||
ClientContext context;
|
||||
VariableMessage msg;
|
||||
VoidMessage out_msg;
|
||||
// FIXME(typhoonzero): pass device context to here.
|
||||
auto ctx = platform::CPUDeviceContext();
|
||||
auto* var = scope.FindVar(inname);
|
||||
PADDLE_ENFORCE(var);
|
||||
SerializeToMessage(inname, var, ctx, &msg);
|
||||
|
||||
Status status = stub_->SendVariable(&context, msg, &out_msg);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "gRPC error: " << status.error_message();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RPCClient::GetVariable(const framework::Scope& scope,
|
||||
const std::string& outname) {
|
||||
ClientContext context;
|
||||
VariableMessage call_msg, ret_msg;
|
||||
call_msg.set_varname(outname);
|
||||
auto ctx = platform::CPUDeviceContext();
|
||||
Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
|
||||
auto* outvar = scope.FindVar(outname);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "gRPC error: " << status.error_message();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::istringstream iss(ret_msg.serialized());
|
||||
DeserializeFromMessage(ret_msg, ctx, outvar);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void RPCClient::Wait() {
|
||||
ClientContext context;
|
||||
VoidMessage call_msg, ret_msg;
|
||||
stub_->Wait(&context, call_msg, &ret_msg);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -1,141 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/framework/var_type.h"
|
||||
#include "paddle/operators/detail/simple_block_queue.h"
|
||||
|
||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||
#include "paddle/operators/detail/send_recv.pb.h"
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
|
||||
using grpc::Channel;
|
||||
using grpc::Server;
|
||||
using grpc::ServerContext;
|
||||
using grpc::ServerReader;
|
||||
using grpc::ServerBuilder;
|
||||
|
||||
using grpc::ClientContext;
|
||||
using grpc::ClientReader;
|
||||
using grpc::ClientReaderWriter;
|
||||
using grpc::ClientWriter;
|
||||
using grpc::Status;
|
||||
using sendrecv::SendRecvService;
|
||||
using sendrecv::VariableMessage;
|
||||
using sendrecv::VoidMessage;
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||
|
||||
class SendRecvServerImpl final : public SendRecvService::Service {
|
||||
public:
|
||||
explicit SendRecvServerImpl() {}
|
||||
|
||||
Status SendVariable(ServerContext *context, const VariableMessage *in_var,
|
||||
VoidMessage *out_var) override;
|
||||
Status GetVariable(ServerContext *context, const VariableMessage *in_var,
|
||||
VariableMessage *out_var) override;
|
||||
Status Wait(ServerContext *context, const VoidMessage *in_var,
|
||||
VoidMessage *out_var) override;
|
||||
void Reset();
|
||||
void Done();
|
||||
void SetScope(framework::Scope *scope) { scope_ = scope; };
|
||||
|
||||
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
||||
|
||||
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
||||
|
||||
private:
|
||||
// received variable from RPC, operators fetch variable from this queue.
|
||||
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
||||
framework::Scope *scope_;
|
||||
// condition of the sub program
|
||||
std::mutex mutex_;
|
||||
bool done_;
|
||||
std::condition_variable condition_;
|
||||
};
|
||||
|
||||
// RPCClient is a class to send tensors to pserver sub-network
|
||||
// using different hashing methods.
|
||||
class RPCClient {
|
||||
public:
|
||||
RPCClient(std::shared_ptr<Channel> channel)
|
||||
: stub_(SendRecvService::NewStub(channel)) {}
|
||||
|
||||
bool SendVariable(const framework::Scope &scope, const std::string &inname);
|
||||
bool GetVariable(const framework::Scope &scope, const std::string &outname);
|
||||
void Wait();
|
||||
|
||||
private:
|
||||
std::unique_ptr<SendRecvService::Stub> stub_;
|
||||
};
|
||||
|
||||
inline void SerializeToMessage(const std::string &name,
|
||||
const framework::Variable *var,
|
||||
const platform::DeviceContext &ctx,
|
||||
VariableMessage *msg) {
|
||||
msg->set_varname(name);
|
||||
std::ostringstream oss;
|
||||
switch (framework::ToVarType(var->Type())) {
|
||||
case framework::proto::VarDesc_VarType_LOD_TENSOR:
|
||||
msg->set_type(sendrecv::VarType::LOD_TENSOR);
|
||||
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
|
||||
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
|
||||
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
default: {
|
||||
PADDLE_THROW("Serialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg->set_serialized(oss.str());
|
||||
}
|
||||
|
||||
inline void DeserializeFromMessage(const VariableMessage &msg,
|
||||
const platform::DeviceContext &ctx,
|
||||
framework::Variable *var) {
|
||||
using namespace paddle::framework::proto;
|
||||
std::istringstream iss(msg.serialized());
|
||||
switch (msg.type()) {
|
||||
case sendrecv::VarType::LOD_TENSOR:
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case sendrecv::VarType::SELECTED_ROWS: {
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PADDLE_THROW("Deserialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,68 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
void SerializeToMessage(const std::string& name, const framework::Variable* var,
|
||||
const platform::DeviceContext& ctx,
|
||||
sendrecv::VariableMessage* msg) {
|
||||
msg->set_varname(name);
|
||||
std::ostringstream oss;
|
||||
switch (framework::ToVarType(var->Type())) {
|
||||
case framework::proto::VarDesc_VarType_LOD_TENSOR:
|
||||
msg->set_type(sendrecv::VarType::LOD_TENSOR);
|
||||
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
|
||||
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
|
||||
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
default: {
|
||||
PADDLE_THROW("Serialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg->set_serialized(oss.str());
|
||||
}
|
||||
|
||||
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
|
||||
const platform::DeviceContext& ctx,
|
||||
framework::Variable* var) {
|
||||
std::istringstream iss(msg.serialized());
|
||||
switch (msg.type()) {
|
||||
case sendrecv::VarType::LOD_TENSOR:
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
|
||||
break;
|
||||
case sendrecv::VarType::SELECTED_ROWS: {
|
||||
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
|
||||
ctx);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
PADDLE_THROW("Deserialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,42 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/framework/data_type.h"
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/selected_rows.h"
|
||||
#include "paddle/framework/var_type.h"
|
||||
|
||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||
#include "paddle/operators/detail/send_recv.pb.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
void SerializeToMessage(const std::string& name, const framework::Variable* var,
|
||||
const platform::DeviceContext& ctx,
|
||||
sendrecv::VariableMessage* msg);
|
||||
|
||||
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
|
||||
const platform::DeviceContext& ctx,
|
||||
framework::Variable* var);
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Loading…
Reference in new issue