move dist codes from operaotrs/detail to operators/distributed

revert-11610-move_hooks
Yancey1989 8 years ago
parent 4b7ae1452b
commit 1ef6cdb60e

@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#endif #endif
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
@ -49,8 +49,8 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
void Executor::Complete() { void Executor::Complete() {
::paddle::operators::detail::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::detail::GRPCClient>() ::paddle::operators::distributed::GRPCClient>()
->SendComplete(); ->SendComplete();
} }
#endif #endif

@ -184,8 +184,8 @@ else()
set(DEPS_OPS ${DEPS_OPS} nccl_op) set(DEPS_OPS ${DEPS_OPS} nccl_op)
endif() endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)

@ -15,13 +15,13 @@
#pragma once #pragma once
#ifdef PADDLE_WITH_GRPC #ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
#define RPCSERVER_T detail::AsyncGRPCServer #define RPCSERVER_T distributed::AsyncGRPCServer
#define RPCCLIENT_T detail::GRPCClient #define RPCCLIENT_T distributed::GRPCClient
#else #else
#include "paddle/fluid/operators/detail/brpc_client.h" #include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/operators/detail/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc_server.h"
#define RPCSERVER_T detail::AsyncBRPCServer #define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCCLIENT_T detail::BRPCClient #define RPCCLIENT_T distributed::BRPCClient
#endif #endif

@ -1,8 +1,3 @@
if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC) if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/brpc_client.h" #include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
DEFINE_int32(brpc_channel_num, 24, DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server"); "Number of channels to send requests connected to one server");
@ -175,6 +175,6 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
return q; return q;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -31,13 +31,13 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
struct ChannelContext { struct ChannelContext {
brpc::Channel channel; brpc::Channel channel;
@ -95,6 +95,6 @@ class BRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(BRPCClient); DISABLE_COPY_AND_ASSIGN(BRPCClient);
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv { namespace sendrecv {
typedef std::unordered_map<std::string, typedef std::unordered_map<std::string,
paddle::operators::detail::RequestHandler*> paddle::operators::distributed::RequestHandler*>
HandlerMap; HandlerMap;
class BRPCServiceImpl : public SendRecvService { class BRPCServiceImpl : public SendRecvService {
@ -27,17 +27,17 @@ class BRPCServiceImpl : public SendRecvService {
: request_send_h_(nullptr), : request_send_h_(nullptr),
request_get_h_(nullptr), request_get_h_(nullptr),
request_prefetch_h_(nullptr) { request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend); auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_send_h_ = it->second; request_send_h_ = it->second;
} }
it = rpc_call_map.find(paddle::operators::detail::kRequestSend); it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_get_h_ = it->second; request_get_h_ = it->second;
} }
it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch); it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second; request_prefetch_h_ = it->second;
} }
@ -88,15 +88,15 @@ class BRPCServiceImpl : public SendRecvService {
} }
private: private:
paddle::operators::detail::RequestHandler* request_send_h_; paddle::operators::distributed::RequestHandler* request_send_h_;
paddle::operators::detail::RequestHandler* request_get_h_; paddle::operators::distributed::RequestHandler* request_get_h_;
paddle::operators::detail::RequestHandler* request_prefetch_h_; paddle::operators::distributed::RequestHandler* request_prefetch_h_;
}; };
} // namespace sendrecv } // namespace sendrecv
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void AsyncBRPCServer::StartServer() { void AsyncBRPCServer::StartServer() {
// Instance of your service. // Instance of your service.
@ -139,6 +139,6 @@ void AsyncBRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer WaitSeverReady"; VLOG(3) << "AsyncGRPCServer WaitSeverReady";
} }
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle

@ -19,12 +19,12 @@ limitations under the License. */
#include <string> #include <string>
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class AsyncBRPCServer final : public RPCServer { class AsyncBRPCServer final : public RPCServer {
public: public:
@ -48,6 +48,6 @@ class AsyncBRPCServer final : public RPCServer {
int ready_; int ready_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle

@ -17,11 +17,11 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC // file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data. // requests without too much copying of the tensor data.
#include "paddle/fluid/operators/detail/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
GrpcByteBufferSource::GrpcByteBufferSource() {} GrpcByteBufferSource::GrpcByteBufferSource() {}
@ -83,6 +83,6 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
return byte_count_; return byte_count_;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -106,7 +106,7 @@ class GrpcBufferReader final
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
// Source provides a way for a particular RPC implementation to provide // Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom. // received data to ParseFrom.
class Source { class Source {
@ -183,6 +183,6 @@ class GrpcByteSource : public Source {
char space_[sizeof(Reader)]; char space_[sizeof(Reader)];
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/distributed/grpc_client.h"
#include <sys/time.h> #include <sys/time.h>
#include <limits> #include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void GRPCClient::InitImpl() { InitEventLoop(); } void GRPCClient::InitImpl() { InitEventLoop(); }
@ -276,6 +276,6 @@ std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
return ch; return ch;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -38,13 +38,13 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
struct VarHandle { struct VarHandle {
std::string ep; std::string ep;
@ -226,6 +226,6 @@ class GRPCClient : public RPCClient {
DISABLE_COPY_AND_ASSIGN(GRPCClient); DISABLE_COPY_AND_ASSIGN(GRPCClient);
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -21,8 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
@ -50,7 +50,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < 564; ++i) rows->push_back(i); for (int i = 0; i < 564; ++i) rows->push_back(i);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0)); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
@ -81,10 +81,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// deserialize zero-copy // deserialize zero-copy
// framework::Variable var2; // framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx); operators::distributed::VariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar(); framework::Variable* var2 = resp.GetVar();
@ -128,7 +128,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
math::set_constant(ctx, tensor, 31.9); math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg; ::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); operators::distributed::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), static_cast<size_t>(0)); EXPECT_GT(msg.Length(), static_cast<size_t>(0));
// deserialize // deserialize
@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy // deserialize zero-copy
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::detail::VariableResponse resp(&scope, &ctx); operators::distributed::VariableResponse resp(&scope, &ctx);
if (from_type == 0) { if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
} else { } else {

@ -15,13 +15,13 @@ limitations under the License. */
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum CallStatus { PROCESS = 0, FINISH }; enum CallStatus { PROCESS = 0, FINISH };
// reference: // reference:
@ -74,7 +74,7 @@ class RequestSend final : public RequestBase {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), request_handler->dev_ctx(),
!request_handler->sync_mode())); !request_handler->sync_mode()));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
@ -106,7 +106,7 @@ class RequestGet final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_, method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
@ -150,7 +150,8 @@ class RequestPrefetch final : public RequestBase {
local_scope_(nullptr) { local_scope_(nullptr) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true)); request_handler->dev_ctx(), true));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id =
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
@ -354,6 +355,6 @@ void AsyncGRPCServer::HandleRequest(
} }
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -29,17 +29,17 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h" #include "paddle/fluid/operators/distributed/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RequestBase; class RequestBase;
@ -84,6 +84,6 @@ class AsyncGRPCServer final : public RPCServer {
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_; std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle

@ -23,7 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h> #include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h> #include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
@ -42,24 +42,25 @@ class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse. // Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse. // Wire-format is identical to RecvVariableResponse.
template <> template <>
class SerializationTraits<paddle::operators::detail::VariableResponse> { class SerializationTraits<paddle::operators::distributed::VariableResponse> {
public: public:
static Status Serialize( static Status Serialize(
const paddle::operators::detail::VariableResponse& msg, const paddle::operators::distributed::VariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) { grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status(); return Status();
} }
static Status Deserialize(grpc_byte_buffer* buffer, static Status Deserialize(
paddle::operators::detail::VariableResponse* msg, grpc_byte_buffer* buffer,
int max_message_size = INT_MAX) { paddle::operators::distributed::VariableResponse* msg,
int max_message_size = INT_MAX) {
if (buffer == nullptr) { if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload"); return Status(StatusCode::INTERNAL, "No payload");
} }
Status result = g_core_codegen_interface->ok(); Status result = g_core_codegen_interface->ok();
if (result.ok()) { if (result.ok()) {
paddle::operators::detail::GrpcByteSource source(buffer); paddle::operators::distributed::GrpcByteSource source(buffer);
int ret = msg->Parse(&source); int ret = msg->Parse(&source);
if (ret != 0) { if (ret != 0) {
result = Status(StatusCode::INTERNAL, "VariableResponse parse error"); result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
@ -73,7 +74,7 @@ class SerializationTraits<paddle::operators::detail::VariableResponse> {
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
enum class GrpcMethod { enum class GrpcMethod {
kSendVariable, kSendVariable,
@ -118,6 +119,6 @@ class GrpcService final {
}; };
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -26,7 +26,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
char* EncodeVarint32(char* dst, uint32_t v) { char* EncodeVarint32(char* dst, uint32_t v) {
// Operate on characters as unsigneds // Operate on characters as unsigneds
@ -144,6 +144,6 @@ class ProtoEncodeHelper {
char* limit_; // Just for CHECKs char* limit_; // Just for CHECKs
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -31,7 +31,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestGet[] = "RequestGet";
@ -124,6 +124,6 @@ class RequestHandler {
RPCServer* rpc_server_; RPCServer* rpc_server_;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -20,12 +20,12 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
@ -119,6 +119,6 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true; return true;
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -28,11 +28,11 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RequestSendHandler final : public RequestHandler { class RequestSendHandler final : public RequestHandler {
public: public:
@ -66,6 +66,6 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override; const std::string& out_var_name = "") override;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
std::once_flag RPCClient::init_flag_; std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr); std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -22,7 +22,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RPCClient { class RPCClient {
public: public:
@ -84,6 +84,6 @@ class RPCClient {
static std::once_flag init_flag_; static std::once_flag init_flag_;
static std::unique_ptr<RPCClient> rpc_client_; static std::unique_ptr<RPCClient> rpc_client_;
}; };
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -17,11 +17,11 @@
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
void RPCServer::ShutDown() { void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown "; LOG(INFO) << "RPCServer ShutDown ";
@ -112,6 +112,6 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); }); lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
} }
} // namespace detail } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -19,11 +19,11 @@
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace distributed {
class RPCServer { class RPCServer {
public: public:
@ -86,6 +86,6 @@ class RPCServer {
friend class RequestHandler; friend class RequestHandler;
}; };
}; // namespace detail }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle

@ -22,18 +22,18 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace detail = paddle::operators::detail; namespace distributed = paddle::operators::distributed;
USE_OP(lookup_table); USE_OP(lookup_table);
std::unique_ptr<detail::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
@ -113,19 +113,21 @@ void StartServer() {
g_req_handler->SetScope(&scope); g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get()); g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::RPCServer::StartServer, g_rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join(); server_thread.join();
} }
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true)); g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save