Reduce memory copy when communicating between trainer and pserver. (#9271)
parent
b594251f89
commit
990d6396fe
@ -1,6 +1,8 @@
|
||||
if(WITH_DISTRIBUTE)
|
||||
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
|
||||
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
|
||||
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc)
|
||||
cc_test(serde_test SRCS test_serde.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
|
||||
cares zlib protobuf sendrecvop_grpc)
|
||||
endif()
|
||||
|
@ -0,0 +1,118 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <grpc++/impl/codegen/async_stream.h>
|
||||
#include <grpc++/impl/codegen/async_unary_call.h>
|
||||
#include <grpc++/impl/codegen/proto_utils.h>
|
||||
#include <grpc++/impl/codegen/rpc_method.h>
|
||||
#include <grpc++/impl/codegen/service_type.h>
|
||||
#include <grpc++/impl/codegen/status.h>
|
||||
#include <grpc++/impl/codegen/stub_options.h>
|
||||
#include <grpc++/impl/codegen/sync_stream.h>
|
||||
#include <grpc++/support/byte_buffer.h>
|
||||
#include "paddle/fluid/operators/detail/variable_response.h"
|
||||
|
||||
// NOTE: This method was originally created by tensorflow
|
||||
// (https://github.com/tensorflow/tensorflow/) we borrow this
|
||||
// method and did some modifications so that we can parse gRPC
|
||||
// requests without too much copying of the tensor data.
|
||||
|
||||
namespace grpc {
|
||||
class CompletionQueue;
|
||||
class Channel;
|
||||
class RpcService;
|
||||
class ServerCompletionQueue;
|
||||
class ServerContext;
|
||||
|
||||
// Support parsing/unparsing of paddle::operators::detail::VariableResponse.
|
||||
// Wire-format is identical to RecvVariableResponse.
|
||||
template <>
|
||||
class SerializationTraits<paddle::operators::detail::VariableResponse> {
|
||||
public:
|
||||
static Status Serialize(
|
||||
const paddle::operators::detail::VariableResponse& msg,
|
||||
grpc_byte_buffer** bp, bool* own_buffer) {
|
||||
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
|
||||
return Status();
|
||||
}
|
||||
static Status Deserialize(grpc_byte_buffer* buffer,
|
||||
paddle::operators::detail::VariableResponse* msg,
|
||||
int max_message_size = INT_MAX) {
|
||||
if (buffer == nullptr) {
|
||||
return Status(StatusCode::INTERNAL, "No payload");
|
||||
}
|
||||
|
||||
Status result = g_core_codegen_interface->ok();
|
||||
if (result.ok()) {
|
||||
paddle::operators::detail::GrpcByteSource source(buffer);
|
||||
int ret = msg->Parse(&source);
|
||||
if (ret != 0) {
|
||||
result = Status(StatusCode::INTERNAL, "VariableResponse parse error");
|
||||
}
|
||||
}
|
||||
g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
} // namespace grpc
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
// Identifiers for the RPC methods handled in this file. The enumerators are
// numbered consecutively from zero because other code in this file casts them
// to int and uses them as method indices; keep kGetVariable last, since the
// method count is derived from it.
enum class GrpcMethod {
  kSendVariable = 0,
  kGetVariable = 1,
};
|
||||
|
||||
// Number of RPC methods: one more than the value of the last GrpcMethod
// enumerator.
static const int kGrpcNumMethods =
    1 + static_cast<int>(GrpcMethod::kGetVariable);
|
||||
|
||||
// Maps a GrpcMethod id to its "/sendrecv.SendRecvService/<Method>" route
// string. For a value that matches no enumerator, the PADDLE_ENFORCE(false,
// ...) check fires; the nullptr return after it only satisfies the signature.
inline const char* GrpcMethodName(GrpcMethod id) {
  const char* route = nullptr;

  switch (id) {
    case GrpcMethod::kSendVariable:
      route = "/sendrecv.SendRecvService/SendVariable";
      break;
    case GrpcMethod::kGetVariable:
      route = "/sendrecv.SendRecvService/GetVariable";
      break;
  }

  if (route != nullptr) {
    return route;
  }

  // Shouldn't be reached.
  PADDLE_ENFORCE(false, "Invalid id: not found valid method name");
  return nullptr;
}
|
||||
|
||||
class GrpcService final {
|
||||
public:
|
||||
class AsyncService : public ::grpc::Service {
|
||||
public:
|
||||
AsyncService() {
|
||||
for (int i = 0; i < kGrpcNumMethods; ++i) {
|
||||
AddMethod(new ::grpc::internal::RpcServiceMethod(
|
||||
GrpcMethodName(static_cast<GrpcMethod>(i)),
|
||||
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
|
||||
::grpc::Service::MarkMethodAsync(i);
|
||||
}
|
||||
}
|
||||
virtual ~AsyncService() {}
|
||||
|
||||
// Make RequestAsyncUnary public for grpc_call.h
|
||||
using ::grpc::Service::RequestAsyncUnary;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
}  // namespace operators
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,81 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
|
||||
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
|
||||
#include "paddle/fluid/operators/detail/send_recv.pb.h"
|
||||
|
||||
#include "google/protobuf/io/coded_stream.h"
|
||||
#include "google/protobuf/io/zero_copy_stream.h"
|
||||
#include "paddle/fluid/framework/tensor.h"
|
||||
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
class VariableResponse {
|
||||
public:
|
||||
VariableResponse(const framework::Scope* scope,
|
||||
const platform::DeviceContext* dev_ctx)
|
||||
: scope_(scope), dev_ctx_(dev_ctx){};
|
||||
|
||||
virtual ~VariableResponse(){};
|
||||
|
||||
// return:
|
||||
// 0:ok.
|
||||
// -1: unkown error.
|
||||
// other: number of error field.
|
||||
int Parse(Source* source);
|
||||
|
||||
// return:
|
||||
// 0:ok.
|
||||
// -1: unkown error.
|
||||
// other: number of error field.
|
||||
int Parse(const ::grpc::ByteBuffer& byte_buffer);
|
||||
|
||||
inline std::string Varname() { return meta_.varname(); }
|
||||
|
||||
// should call parse first.
|
||||
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
|
||||
|
||||
private:
|
||||
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
|
||||
const platform::DeviceContext& ctx,
|
||||
framework::DDim& dims, int length);
|
||||
|
||||
bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
|
||||
const platform::DeviceContext& ctx, int length);
|
||||
|
||||
bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
|
||||
const platform::DeviceContext& ctx,
|
||||
framework::DDim& dims, int length);
|
||||
|
||||
private:
|
||||
const framework::Scope* scope_;
|
||||
const platform::DeviceContext* dev_ctx_;
|
||||
// only Skeleton
|
||||
sendrecv::VariableMessage meta_;
|
||||
};
|
||||
|
||||
}; // namespace detail
|
||||
}; // namespace operators
|
||||
}; // namespace paddle
|
Loading…
Reference in new issue