Add brpc surpport. (#11263)
parent
bfa3fd6f15
commit
d9de6b8621
@ -0,0 +1,58 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
INCLUDE(ExternalProject)
|
||||
|
||||
SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
|
||||
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
|
||||
SET(BRPC_INCLUDE_DIR "${BRPC_INSTALL_DIR}/include" CACHE PATH "brpc include directory." FORCE)
|
||||
SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc library." FORCE)
|
||||
|
||||
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
|
||||
|
||||
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
|
||||
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf")
|
||||
|
||||
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
|
||||
ExternalProject_Add(
|
||||
extern_brpc
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
GIT_REPOSITORY "https://github.com/brpc/brpc"
|
||||
GIT_TAG "6d153dd7ff00f960ae6895c9c5fff0ce9f07aff2"
|
||||
PREFIX ${BRPC_SOURCES_DIR}
|
||||
UPDATE_COMMAND ""
|
||||
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
|
||||
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
|
||||
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
|
||||
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
|
||||
-DCMAKE_INSTALL_PREFIX=${BRPC_INSTALL_DIR}
|
||||
-DCMAKE_INSTALL_LIBDIR=${BRPC_INSTALL_DIR}/lib
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
|
||||
-DCMAKE_PREFIX_PATH=${prefix_path}
|
||||
-DBRPC_WITH_GLOG=ON
|
||||
${EXTERNAL_OPTIONAL_ARGS}
|
||||
LIST_SEPARATOR |
|
||||
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BRPC_INSTALL_DIR}
|
||||
-DCMAKE_INSTALL_LIBDIR:PATH=${BRPC_INSTALL_DIR}/lib
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
|
||||
)
|
||||
ADD_DEPENDENCIES(extern_brpc protobuf leveldb gflags glog gtest snappy)
|
||||
ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
|
||||
ADD_DEPENDENCIES(brpc extern_brpc)
|
||||
|
||||
|
||||
LIST(APPEND external_project_dependencies brpc)
|
@ -0,0 +1,44 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
INCLUDE(ExternalProject)
|
||||
|
||||
SET(LEVELDB_SOURCES_DIR ${THIRD_PARTY_PATH}/leveldb)
|
||||
SET(LEVELDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/leveldb)
|
||||
SET(LEVELDB_INCLUDE_DIR "${LEVELDB_INSTALL_DIR}/include" CACHE PATH "leveldb include directory." FORCE)
|
||||
SET(LEVELDB_LIBRARIES "${LEVELDB_INSTALL_DIR}/lib/libleveldb.a" CACHE FILEPATH "leveldb library." FORCE)
|
||||
INCLUDE_DIRECTORIES(${LEVELDB_INCLUDE_DIR})
|
||||
|
||||
ExternalProject_Add(
|
||||
extern_leveldb
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
PREFIX ${LEVELDB_SOURCES_DIR}
|
||||
URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
|
||||
URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
|
||||
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
|
||||
&& cp ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/libleveldb.a ${LEVELDB_LIBRARIES}
|
||||
&& cp -r ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/include ${LEVELDB_INSTALL_DIR}/
|
||||
BUILD_IN_SOURCE 1
|
||||
)
|
||||
|
||||
ADD_DEPENDENCIES(extern_leveldb snappy)
|
||||
|
||||
ADD_LIBRARY(leveldb STATIC IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET leveldb PROPERTY IMPORTED_LOCATION ${LEVELDB_LIBRARIES})
|
||||
ADD_DEPENDENCIES(leveldb extern_leveldb)
|
||||
|
||||
LIST(APPEND external_project_dependencies leveldb)
|
||||
|
@ -1,12 +1,38 @@
|
||||
if(WITH_DISTRIBUTE)
|
||||
if(NOT WITH_DISTRIBUTE)
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
if(WITH_GRPC)
|
||||
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
|
||||
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
|
||||
selected_rows memory)
|
||||
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
|
||||
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
|
||||
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
|
||||
cares zlib protobuf sendrecvop_grpc SERIAL)
|
||||
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc
|
||||
cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
|
||||
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
|
||||
proto_desc lookup_table_op SERIAL)
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
|
||||
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
|
||||
PROTO send_recv.proto
|
||||
DEPS lod_tensor selected_rows memory)
|
||||
|
||||
find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so)
|
||||
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
|
||||
|
||||
|
||||
find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so)
|
||||
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
|
||||
|
||||
cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
|
||||
brpc protobuf leveldb gflags glog
|
||||
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
|
||||
|
@ -0,0 +1,180 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/detail/brpc_client.h"
|
||||
#include "paddle/fluid/framework/threadpool.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
DEFINE_int32(brpc_channel_num, 24,
|
||||
"Number of channels to send requests connected to one server");
|
||||
DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
|
||||
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
|
||||
|
||||
BRPCClient::~BRPCClient() { Wait(); }
|
||||
|
||||
void HandleSendResponse(brpc::Controller* cntl,
|
||||
sendrecv::VoidMessage* response) {
|
||||
// std::unique_ptr makes sure cntl/response will be deleted before returning.
|
||||
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
|
||||
std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
|
||||
|
||||
if (cntl->Failed()) {
|
||||
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
|
||||
return;
|
||||
}
|
||||
LOG(INFO) << "Received response from " << cntl->remote_side()
|
||||
<< " latency=" << cntl->latency_us() << "us";
|
||||
}
|
||||
|
||||
bool BRPCClient::AsyncSendVar(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name, int64_t time_out) {
|
||||
const platform::DeviceContext* p_ctx = &ctx;
|
||||
const std::string ep_val = ep;
|
||||
const std::string var_name_val = var_name;
|
||||
const framework::Scope* p_scope = &scope;
|
||||
const auto ch_ptr = GetChannel(ep_val);
|
||||
|
||||
framework::AsyncIO(
|
||||
[var_name_val, p_ctx, ep_val, p_scope, time_out, ch_ptr, this] {
|
||||
auto ch_ctx = ch_ptr->Pop();
|
||||
brpc::Controller* cntl = new brpc::Controller();
|
||||
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
|
||||
cntl->set_timeout_ms(time_out);
|
||||
|
||||
google::protobuf::Closure* done =
|
||||
brpc::NewCallback(&HandleSendResponse, cntl, response);
|
||||
|
||||
sendrecv::VariableMessage request;
|
||||
ch_ctx->stub->SendVariable(cntl, &request, response, done);
|
||||
});
|
||||
req_count_++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void HandleGetResponse(brpc::Controller* cntl,
|
||||
sendrecv::VariableMessage* response) {
|
||||
// std::unique_ptr makes sure cntl/response will be deleted before returning.
|
||||
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
|
||||
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
|
||||
|
||||
if (cntl->Failed()) {
|
||||
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
|
||||
return;
|
||||
}
|
||||
LOG(INFO) << "Received response from " << cntl->remote_side()
|
||||
<< " latency=" << cntl->latency_us() << "us";
|
||||
|
||||
// framework::Variable* outvar = nullptr;
|
||||
// DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
|
||||
}
|
||||
|
||||
bool BRPCClient::AsyncGetVar(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& var_name, int64_t time_out) {
|
||||
const platform::DeviceContext* p_ctx = &ctx;
|
||||
const std::string ep_val = ep;
|
||||
const std::string var_name_val = var_name;
|
||||
const framework::Scope* p_scope = &scope;
|
||||
const auto ch = GetChannel(ep_val);
|
||||
|
||||
framework::AsyncIO(
|
||||
[var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {});
|
||||
|
||||
req_count_++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool BRPCClient::AsyncPrefetchVar(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& in_var_name,
|
||||
const std::string& out_var_name,
|
||||
int64_t time_out) {
|
||||
const platform::DeviceContext* p_ctx = &ctx;
|
||||
const std::string ep_val = ep;
|
||||
const std::string in_var_name_val = in_var_name;
|
||||
const std::string out_var_name_val = out_var_name;
|
||||
const framework::Scope* p_scope = &scope;
|
||||
const auto ch = GetChannel(ep_val);
|
||||
|
||||
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
|
||||
time_out, ch, this] {});
|
||||
|
||||
req_count_++;
|
||||
return true;
|
||||
}
|
||||
|
||||
void BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
|
||||
int64_t time_out) {
|
||||
req_count_++;
|
||||
}
|
||||
|
||||
void BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
|
||||
int64_t time_out) {
|
||||
req_count_++;
|
||||
}
|
||||
|
||||
void BRPCClient::Wait() {
|
||||
std::unique_lock<std::mutex> lk(sync_mutex_);
|
||||
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
|
||||
}
|
||||
|
||||
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(chan_mutex_);
|
||||
auto it = channels_.find(ep);
|
||||
if (it != channels_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
|
||||
|
||||
brpc::ChannelOptions options;
|
||||
options.protocol = "baidu_std";
|
||||
options.connection_type = "pooled";
|
||||
options.connect_timeout_ms = 100;
|
||||
options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
|
||||
options.max_retry = FLAGS_max_retry;
|
||||
for (int i = 0; i < FLAGS_brpc_channel_num; ++i) {
|
||||
std::shared_ptr<ChannelContext> c(new ChannelContext());
|
||||
if (c->channel.Init(ep.c_str(), &options) != 0) {
|
||||
LOG(ERROR) << "Fail to initialize channel";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
c->stub.reset(new sendrecv::SendRecvService_Stub(
|
||||
static_cast<google::protobuf::RpcChannel*>(&c->channel)));
|
||||
q->Push(c);
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(chan_mutex_);
|
||||
channels_[ep] = q;
|
||||
}
|
||||
|
||||
return q;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,100 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <time.h>
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <mutex> // NOLINT
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "brpc/channel.h"
|
||||
#include "paddle/fluid/framework/blocking_queue.h"
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/operators/detail/rpc_client.h"
|
||||
#include "paddle/fluid/operators/detail/send_recv.pb.h"
|
||||
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
struct ChannelContext {
|
||||
brpc::Channel channel;
|
||||
std::shared_ptr<sendrecv::SendRecvService_Stub> stub;
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<ChannelContext> ChannelContextPtr;
|
||||
typedef std::shared_ptr<framework::BlockingQueue<ChannelContextPtr>>
|
||||
ChannelQueuePtr;
|
||||
|
||||
class BRPCClient : public RPCClient {
|
||||
public:
|
||||
BRPCClient() {}
|
||||
virtual ~BRPCClient();
|
||||
|
||||
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope, const std::string& var_name,
|
||||
int64_t time_out = RPCClient::rpc_time_out) override;
|
||||
|
||||
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope, const std::string& var_name,
|
||||
int64_t time_out = RPCClient::rpc_time_out) override;
|
||||
|
||||
bool AsyncPrefetchVar(const std::string& ep,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope& scope,
|
||||
const std::string& in_var_name,
|
||||
const std::string& out_var_name,
|
||||
int64_t time_out = RPCClient::rpc_time_out) override;
|
||||
|
||||
void AsyncSendBatchBarrier(
|
||||
const std::string& ep,
|
||||
int64_t time_out = RPCClient::rpc_time_out) override;
|
||||
|
||||
void AsyncSendFetchBarrier(
|
||||
const std::string& ep,
|
||||
int64_t time_out = RPCClient::rpc_time_out) override;
|
||||
|
||||
void Wait() override;
|
||||
|
||||
private:
|
||||
void Proceed();
|
||||
ChannelQueuePtr GetChannel(const std::string& ep);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, ChannelQueuePtr> channels_;
|
||||
|
||||
// mutex for Wait client sync
|
||||
std::mutex sync_mutex_;
|
||||
std::condition_variable sync_cond_;
|
||||
std::atomic<int64_t> req_count_{0};
|
||||
|
||||
// mutex for GetChannel thread safety
|
||||
std::mutex chan_mutex_;
|
||||
DISABLE_COPY_AND_ASSIGN(BRPCClient);
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,144 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/detail/brpc_server.h"
|
||||
#include "paddle/fluid/operators/detail/request_handler.h"
|
||||
|
||||
namespace sendrecv {
|
||||
|
||||
typedef std::unordered_map<std::string,
|
||||
paddle::operators::detail::RequestHandler*>
|
||||
HandlerMap;
|
||||
|
||||
class BRPCServiceImpl : public SendRecvService {
|
||||
public:
|
||||
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map)
|
||||
: request_send_h_(nullptr),
|
||||
request_get_h_(nullptr),
|
||||
request_prefetch_h_(nullptr) {
|
||||
auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
|
||||
if (it != rpc_call_map.end()) {
|
||||
request_send_h_ = it->second;
|
||||
}
|
||||
|
||||
it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
|
||||
if (it != rpc_call_map.end()) {
|
||||
request_get_h_ = it->second;
|
||||
}
|
||||
|
||||
it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch);
|
||||
if (it != rpc_call_map.end()) {
|
||||
request_prefetch_h_ = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
virtual ~BRPCServiceImpl() {}
|
||||
|
||||
void SendVariable(google::protobuf::RpcController* cntl_butil,
|
||||
const VariableMessage* request, VoidMessage* response,
|
||||
google::protobuf::Closure* done) override {
|
||||
PADDLE_ENFORCE(request_send_h_ != nullptr,
|
||||
"RequestSend handler should be registed first!");
|
||||
brpc::ClosureGuard done_guard(done);
|
||||
|
||||
paddle::framework::Scope* local_scope = request_send_h_->scope();
|
||||
paddle::framework::Variable* outvar = nullptr;
|
||||
paddle::framework::Variable* invar = nullptr;
|
||||
|
||||
std::string varname = request->varname();
|
||||
|
||||
if (!request_send_h_->sync_mode()) {
|
||||
local_scope = &request_send_h_->scope()->NewScope();
|
||||
invar = local_scope->Var(varname);
|
||||
} else {
|
||||
invar = local_scope->FindVar(varname);
|
||||
}
|
||||
|
||||
request_send_h_->Handle(varname, local_scope, invar, &outvar);
|
||||
|
||||
if (!request_send_h_->sync_mode()) {
|
||||
request_send_h_->scope()->DeleteScope(local_scope);
|
||||
}
|
||||
}
|
||||
|
||||
void GetVariable(google::protobuf::RpcController* cntl_butil,
|
||||
const VariableMessage* request, VariableMessage* response,
|
||||
google::protobuf::Closure* done) override {
|
||||
PADDLE_ENFORCE(request_get_h_ != nullptr,
|
||||
"RequestGet handler should be registed first!");
|
||||
}
|
||||
|
||||
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
|
||||
const VariableMessage* request,
|
||||
VariableMessage* response,
|
||||
google::protobuf::Closure* done) override {
|
||||
PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
|
||||
"kRequestPrefetch handler should be registed first!");
|
||||
}
|
||||
|
||||
private:
|
||||
paddle::operators::detail::RequestHandler* request_send_h_;
|
||||
paddle::operators::detail::RequestHandler* request_get_h_;
|
||||
paddle::operators::detail::RequestHandler* request_prefetch_h_;
|
||||
};
|
||||
} // namespace sendrecv
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
void AsyncBRPCServer::StartServer() {
|
||||
// Instance of your service.
|
||||
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_);
|
||||
|
||||
// Add the service into server. Notice the second parameter, because the
|
||||
// service is put on stack, we don't want server to delete it, otherwise
|
||||
// use brpc::SERVER_OWNS_SERVICE.
|
||||
if (server_.AddService(&service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
|
||||
LOG(FATAL) << "Fail to add service";
|
||||
return;
|
||||
}
|
||||
|
||||
brpc::ServerOptions options;
|
||||
options.idle_timeout_sec = idle_timeout_s_;
|
||||
options.max_concurrency = max_concurrency_;
|
||||
if (server_.Start(bind_address_.c_str(), &options) != 0) {
|
||||
LOG(FATAL) << "Fail to start EchoServer" << bind_address_;
|
||||
return;
|
||||
}
|
||||
|
||||
butil::EndPoint ep = server_.listen_address();
|
||||
selected_port_ = ep.port;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(this->mutex_ready_);
|
||||
ready_ = 1;
|
||||
}
|
||||
condition_ready_.notify_all();
|
||||
|
||||
server_.Join();
|
||||
}
|
||||
|
||||
void AsyncBRPCServer::ShutDownImpl() { server_.Stop(1000); }
|
||||
|
||||
void AsyncBRPCServer::WaitServerReady() {
|
||||
VLOG(3) << "AsyncGRPCServer is wait server ready";
|
||||
std::unique_lock<std::mutex> lock(this->mutex_ready_);
|
||||
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
|
||||
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
|
||||
}
|
||||
|
||||
}; // namespace detail
|
||||
}; // namespace operators
|
||||
}; // namespace paddle
|
@ -0,0 +1,53 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <condition_variable> // NOLINT
|
||||
#include <mutex> // NOLINT
|
||||
#include <string>
|
||||
|
||||
#include "brpc/server.h"
|
||||
#include "paddle/fluid/operators/detail/rpc_server.h"
|
||||
#include "paddle/fluid/operators/detail/send_recv.pb.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
class AsyncBRPCServer final : public RPCServer {
|
||||
public:
|
||||
explicit AsyncBRPCServer(const std::string& address, int client_num)
|
||||
: RPCServer(address, client_num), ready_(0) {}
|
||||
|
||||
virtual ~AsyncBRPCServer() {}
|
||||
void StartServer() override;
|
||||
void WaitServerReady() override;
|
||||
|
||||
private:
|
||||
void ShutDownImpl() override;
|
||||
|
||||
brpc::Server server_;
|
||||
|
||||
static constexpr int idle_timeout_s_ = -1;
|
||||
static constexpr int max_concurrency_ = 0;
|
||||
|
||||
std::mutex mutex_ready_;
|
||||
std::condition_variable condition_ready_;
|
||||
int ready_;
|
||||
};
|
||||
|
||||
}; // namespace detail
|
||||
}; // namespace operators
|
||||
}; // namespace paddle
|
@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef PADDLE_WITH_GRPC
|
||||
#include "paddle/fluid/operators/detail/grpc_client.h"
|
||||
#include "paddle/fluid/operators/detail/grpc_server.h"
|
||||
#define RPCSERVER_T detail::AsyncGRPCServer
|
||||
#define RPCCLIENT_T detail::GRPCClient
|
||||
#else
|
||||
#include "paddle/fluid/operators/detail/brpc_client.h"
|
||||
#include "paddle/fluid/operators/detail/brpc_server.h"
|
||||
#define RPCSERVER_T detail::AsyncBRPCServer
|
||||
#define RPCCLIENT_T detail::BRPCClient
|
||||
#endif
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue