Performance/zero copy variable seriralization (#8839)

shanyi15-patch-2
武毅 7 years ago committed by gongweibao
parent 12fc76e1b5
commit 45af8c1e99

@ -187,7 +187,6 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
// TODO(typhoonzero): serialize to ostream
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char*>(&version), sizeof(version));

@ -1,3 +1,6 @@
if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc)
endif()

@ -0,0 +1,88 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#include "bytebuffer_stream.h"
namespace paddle {
namespace operators {
namespace detail {
GrpcByteBufferSource::GrpcByteBufferSource() {}
bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) {
cur_ = -1;
left_ = 0;
ptr_ = nullptr;
byte_count_ = 0;
bool ok = src.Dump(&slices_).ok();
if (!ok) {
slices_.clear();
}
return ok;
}
bool GrpcByteBufferSource::Next(const void** data, int* size) {
// Use loop instead of if in case buffer contained empty slices.
while (left_ == 0) {
// Advance to next slice.
cur_++;
if (cur_ >= slices_.size()) {
return false;
}
const ::grpc::Slice& s = slices_[cur_];
left_ = s.size();
ptr_ = reinterpret_cast<const char*>(s.begin());
}
*data = ptr_;
*size = left_;
byte_count_ += left_;
ptr_ += left_;
left_ = 0;
return true;
}
void GrpcByteBufferSource::BackUp(int count) {
ptr_ -= count;
left_ += count;
byte_count_ -= count;
}
bool GrpcByteBufferSource::Skip(int count) {
const void* data;
int size;
while (Next(&data, &size)) {
if (size >= count) {
BackUp(size - count);
return true;
}
// size < count;
count -= size;
}
// error or we have too large count;
return false;
}
google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
return byte_count_;
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,51 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#pragma once
#include <grpc++/grpc++.h>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
namespace paddle {
namespace operators {
namespace detail {
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class GrpcByteBufferSource
: public ::google::protobuf::io::ZeroCopyInputStream {
public:
GrpcByteBufferSource();
bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times.
bool Next(const void** data, int* size) override;
void BackUp(int count) override;
bool Skip(int count) override;
::google::protobuf::int64 ByteCount() const override;
private:
std::vector<::grpc::Slice> slices_;
size_t cur_; // Current slice index.
int left_; // Number of bytes in slices_[cur_] left to yield.
const char* ptr_; // Address of next byte in slices_[cur_] to yield.
::google::protobuf::int64 byte_count_;
};
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,147 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#pragma once
#include <grpc++/grpc++.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace detail {
char* EncodeVarint32(char* dst, uint32_t v) {
// Operate on characters as unsigneds
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
static const int B = 128;
if (v < (1 << 7)) {
*(ptr++) = v;
} else if (v < (1 << 14)) {
*(ptr++) = v | B;
*(ptr++) = v >> 7;
} else if (v < (1 << 21)) {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = v >> 14;
} else if (v < (1 << 28)) {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = (v >> 14) | B;
*(ptr++) = v >> 21;
} else {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = (v >> 14) | B;
*(ptr++) = (v >> 21) | B;
*(ptr++) = v >> 28;
}
return reinterpret_cast<char*>(ptr);
}
char* EncodeVarint64(char* dst, uint64_t v) {
static const int B = 128;
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
while (v >= B) {
*(ptr++) = (v & (B - 1)) | B;
v >>= 7;
}
*(ptr++) = static_cast<unsigned char>(v);
return reinterpret_cast<char*>(ptr);
}
int VarintLength(uint64_t v) {
int len = 1;
while (v >= 128) {
v >>= 7;
len++;
}
return len;
}
class ProtoEncodeHelper {
public:
ProtoEncodeHelper(char* buf, int max_size)
: base_(buf), p_(buf), limit_(base_ + max_size) {}
~ProtoEncodeHelper() {
// Make sure callers didn't do operations that went over max_size promised
PADDLE_ENFORCE_LE(p_, limit_);
}
const char* data() const { return base_; }
size_t size() const { return p_ - base_; }
void WriteUint64(int tag, uint64_t v) {
Encode32(combine(tag, WIRETYPE_VARINT));
Encode64(v);
}
void WriteBool(int tag, bool v) {
Encode32(combine(tag, WIRETYPE_VARINT));
EncodeBool(v);
}
void WriteString(int tag, const std::string& v) {
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
Encode32(v.size());
EncodeBytes(v.data(), v.size());
}
void WriteVarlengthBeginning(int tag, uint32_t len) {
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
Encode32(len);
}
void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), v.size()); }
private:
// Note: this module's behavior must match the protocol buffer wire encoding
// format.
enum {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
static uint32_t combine(uint32_t tag, uint32_t type) {
return ((tag << 3) | type);
}
inline void Encode32(uint32_t v) {
if (v < 128) {
// Fast path for single-byte values. Many of the calls will use a
// constant value for v, so the comparison will get optimized away
// when Encode32 is inlined into the caller.
*p_ = v;
p_++;
} else {
p_ = EncodeVarint32(p_, v);
}
}
void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); }
void EncodeBool(bool v) {
*p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1
p_++;
}
void EncodeBytes(const char* bytes, int N) {
memcpy(p_, bytes, N);
p_ += N;
}
char* base_;
char* p_;
char* limit_; // Just for CHECKs
};
} // detail
} // operators
} // paddle

@ -33,10 +33,34 @@ enum VarType {
}
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
bytes serialized = 3;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// tensor data
bytes serialized = 7;
// selected_rows data
bytes rows = 8;
}
message VoidMessage {}

File diff suppressed because it is too large Load Diff

@ -33,6 +33,14 @@ namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
typedef void (*DestroyCallback)(void*);
inline int64_t GetTimestamp() {
return std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
}
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
@ -40,6 +48,32 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var,
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
const platform::DeviceContext& ctx,
framework::Variable* var);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg);
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
framework::Variable* var);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
switch (type) {
case sendrecv::VariableMessage::FP32:
return typeid(float); // NOLINT
case sendrecv::VariableMessage::FP64:
return typeid(double); // NOLINT
case sendrecv::VariableMessage::INT32:
return typeid(int); // NOLINT
case sendrecv::VariableMessage::INT64:
return typeid(int64_t); // NOLINT
case sendrecv::VariableMessage::BOOL:
return typeid(bool); // NOLINT
default:
PADDLE_THROW("Not support type %d", type);
}
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -0,0 +1,195 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestTensor(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
float* orig_tensor_data = tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), 0);
// deserialize
std::vector<::grpc::Slice> slices;
(void)msg.Dump(&slices);
std::string tmp;
for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
for (int i = 0; i < varmsg.serialized().size(); ++i) {
printf("%02X ", varmsg.serialized().data()[i]);
}
printf("\n");
for (int i = 0; i < tensor_numel; ++i) {
std::cout << "#####tensor data: " << tensor_data[i] << std::endl;
EXPECT_EQ(tensor_data[i], orig_tensor_data[i]);
std::cout << "test end 1 " << std::endl;
}
std::cout << "tensor data end " << std::endl;
// deserialize zero-copy
framework::Variable var2;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
auto tensor2 = var2.Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
for (int i = 0; i < tensor_numel; ++i)
EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]);
}
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
int tensor_numel = 2 * 10;
float* orig_tensor_data = tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), 0);
// deserialize
std::vector<::grpc::Slice> slices;
(void)msg.Dump(&slices);
std::string tmp;
for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
const int64_t* rows_data =
reinterpret_cast<const int64_t*>(varmsg.rows().data());
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_EQ(tensor_data[i], orig_tensor_data[i]);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
// deserialize zero-copy
framework::Variable var2;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
auto* slr2 = var2.GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
}
// TEST(SelectedRows, CPU) {
// platform::CPUPlace place;
// RunSerdeTestSelectedRows(place);
// }
// TEST(SelectedRows, GPU) {
// platform::CUDAPlace place;
// RunSerdeTestSelectedRows(place);
// }
TEST(Tensor, CPU) {
platform::CPUPlace place;
RunSerdeTestTensor(place);
}
TEST(Tensor, GPU) {
platform::CUDAPlace place;
RunSerdeTestTensor(place);
}
Loading…
Cancel
Save