Change grpc interface to compatible with brpc. (#12164)
parent
b06309381b
commit
3a6213f493
@ -1,33 +1,43 @@
|
||||
if(NOT WITH_DISTRIBUTE)
|
||||
return()
|
||||
endif()
|
||||
|
||||
if(WITH_GRPC)
|
||||
set(cc_generic_services "false")
|
||||
else()
|
||||
set(cc_generic_services "true")
|
||||
endif()
|
||||
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
|
||||
|
||||
if(WITH_GRPC)
|
||||
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
|
||||
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
|
||||
selected_rows memory)
|
||||
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
|
||||
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
|
||||
PROTO send_recv.proto
|
||||
DEPS lod_tensor selected_rows memory)
|
||||
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
|
||||
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
|
||||
cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
|
||||
cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
|
||||
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
|
||||
proto_desc lookup_table_op SERIAL)
|
||||
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
|
||||
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
|
||||
cc_test(grpc_server_test SRCS rpc_server_test.cc
|
||||
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
|
||||
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
|
||||
|
||||
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
|
||||
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
|
||||
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
|
||||
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
|
||||
PROTO send_recv.proto
|
||||
DEPS lod_tensor selected_rows memory)
|
||||
|
||||
find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so)
|
||||
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
|
||||
|
||||
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
|
||||
|
||||
find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so)
|
||||
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
|
||||
cc_test(brpc_server_test SRCS rpc_server_test.cc
|
||||
DEPS ${brpc_test_depends} SERIAL)
|
||||
|
||||
cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
|
||||
brpc protobuf leveldb gflags glog
|
||||
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
|
||||
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
|
||||
DEPS ${brpc_test_depends} SERIAL)
|
||||
|
@ -0,0 +1,157 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include <nccl.h>
|
||||
#endif
|
||||
#include <sys/time.h>
|
||||
#include <thread> // NOLINT
|
||||
|
||||
#include "google/protobuf/io/coded_stream.h"
|
||||
#include "google/protobuf/io/zero_copy_stream.h"
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
|
||||
#include "paddle/fluid/operators/distributed/grpc_serde.h"
|
||||
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
|
||||
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
|
||||
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
|
||||
#include "paddle/fluid/platform/profiler.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace distributed {
|
||||
|
||||
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
||||
const platform::DeviceContext& ctx,
|
||||
::grpc::ByteBuffer* msg,
|
||||
const std::string& out_name) {
|
||||
// Default DestroyCallback does nothing, When using GPU
|
||||
// the CPU buffer need to be freed.
|
||||
DestroyCallback destroy_callback = [](void* backing) {};
|
||||
VarMsg request;
|
||||
void* payload = nullptr;
|
||||
size_t payload_size;
|
||||
|
||||
request.set_varname(name);
|
||||
// Note: normally the profiler is enabled in 1 trainer, hence only
|
||||
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
|
||||
// servers the trainer's profiling state so that PS can follow the
|
||||
// trainer.
|
||||
if (platform::ShouldSendProfileState()) {
|
||||
if (platform::IsProfileEnabled()) {
|
||||
request.set_profile(platform::kEnableProfiler);
|
||||
} else {
|
||||
request.set_profile(platform::kDisableProfiler);
|
||||
}
|
||||
}
|
||||
if (!out_name.empty()) {
|
||||
request.set_out_varname(out_name);
|
||||
}
|
||||
if (var->IsType<framework::LoDTensor>()) {
|
||||
request.set_type(::sendrecv::LOD_TENSOR);
|
||||
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
|
||||
} else if (var->IsType<framework::SelectedRows>()) {
|
||||
request.set_type(::sendrecv::SELECTED_ROWS);
|
||||
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
} else if (var->IsType<ncclUniqueId>()) {
|
||||
request.set_type(::sendrecv::NCCL_ID);
|
||||
#endif
|
||||
} else {
|
||||
PADDLE_THROW("Serialize does not support type: %s",
|
||||
typeid(var->Type()).name());
|
||||
}
|
||||
|
||||
if (platform::is_gpu_place(ctx.GetPlace())) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
// GPU data is copied to CPU buffer when sending,
|
||||
// free the buffer when possible.
|
||||
destroy_callback = [](void* backing) {
|
||||
platform::CUDAPinnedPlace cuda_pinned;
|
||||
memory::Free(cuda_pinned, backing);
|
||||
};
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string header;
|
||||
request.AppendToString(&header);
|
||||
auto buffer = std::unique_ptr<char[]>(new char[1024]);
|
||||
void* buf = buffer.get();
|
||||
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
|
||||
e.WriteRawBytes(std::string(header.data(), header.size()));
|
||||
// NCCLID is copied directly to the message, return bytebuffer
|
||||
// with only one slice if serializing NCCLID.
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (var->IsType<ncclUniqueId>()) {
|
||||
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
|
||||
NCCL_UNIQUE_ID_BYTES);
|
||||
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
|
||||
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
|
||||
|
||||
// for serialize NCCL_ID
|
||||
::grpc::Slice slices(e.size());
|
||||
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
|
||||
::grpc::ByteBuffer tmp(&slices, 1);
|
||||
msg->Swap(&tmp);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
|
||||
// steal reference of tensor data
|
||||
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
|
||||
int num_slices = 2; // only SelectedRows have rows buffer
|
||||
slices[0] = ::grpc::Slice(e.size());
|
||||
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
|
||||
slices[1] = ::grpc::Slice(
|
||||
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
|
||||
static_cast<char*>(payload)),
|
||||
::grpc::Slice::STEAL_REF);
|
||||
|
||||
if (var->IsType<framework::SelectedRows>()) {
|
||||
auto* slr = var->GetMutable<framework::SelectedRows>();
|
||||
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
|
||||
size_t rows_memory_size =
|
||||
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
|
||||
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
|
||||
slices[2] = ::grpc::Slice(e2.size());
|
||||
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
|
||||
|
||||
slices[3] = ::grpc::Slice(
|
||||
grpc_slice_new_with_user_data(
|
||||
const_cast<void*>(
|
||||
reinterpret_cast<const void*>(slr->rows().data())),
|
||||
rows_memory_size, [](void* backing) {},
|
||||
const_cast<char*>(
|
||||
reinterpret_cast<const char*>(slr->rows().data()))),
|
||||
::grpc::Slice::STEAL_REF);
|
||||
num_slices = 4;
|
||||
}
|
||||
|
||||
::grpc::ByteBuffer tmp(&slices[0], num_slices);
|
||||
msg->Swap(&tmp);
|
||||
}
|
||||
|
||||
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope* scope,
|
||||
framework::Variable** var) {
|
||||
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
|
||||
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
|
||||
*var = resp.GetVar();
|
||||
}
|
||||
|
||||
} // namespace distributed
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,50 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <sys/time.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
|
||||
|
||||
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
|
||||
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace distributed {
|
||||
|
||||
typedef void (*DestroyCallback)(void*);
|
||||
|
||||
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
|
||||
const platform::DeviceContext& ctx,
|
||||
::grpc::ByteBuffer* msg,
|
||||
const std::string& out_varname = std::string());
|
||||
|
||||
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
|
||||
const platform::DeviceContext& ctx,
|
||||
const framework::Scope* scope,
|
||||
framework::Variable** var);
|
||||
|
||||
} // namespace distributed
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
|
||||
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
|
||||
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
|
||||
|
||||
#include "google/protobuf/io/coded_stream.h"
|
||||
#include "google/protobuf/io/zero_copy_stream.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
|
||||
#include "paddle/fluid/operators/distributed/variable_response.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace distributed {
|
||||
|
||||
class GRPCVariableResponse : public VariableResponse {
|
||||
public:
|
||||
GRPCVariableResponse(const framework::Scope* scope,
|
||||
const platform::DeviceContext* dev_ctx,
|
||||
bool create_scope = false)
|
||||
: VariableResponse(scope, dev_ctx, create_scope) {}
|
||||
|
||||
virtual ~GRPCVariableResponse() {}
|
||||
|
||||
int Parse(Source* source) override;
|
||||
|
||||
// return:
|
||||
// 0:ok.
|
||||
// -1: unkown error.
|
||||
// other: number of error field.
|
||||
int Parse(const ::grpc::ByteBuffer& byte_buffer);
|
||||
};
|
||||
|
||||
}; // namespace distributed
|
||||
}; // namespace operators
|
||||
}; // namespace paddle
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue