Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/recordio_file_reader
commit
cfca8a3a26
Before Width: | Height: | Size: 344 KiB After Width: | Height: | Size: 344 KiB |
Before Width: | Height: | Size: 190 KiB After Width: | Height: | Size: 190 KiB |
File diff suppressed because it is too large
Load Diff
@ -1,3 +1,6 @@
|
||||
if(WITH_DISTRIBUTE)
|
||||
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
|
||||
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
|
||||
cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc)
|
||||
endif()
|
||||
|
@ -0,0 +1,88 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
// NOTE: This file was originally created by tensorflow
|
||||
// (https://github.com/tensorflow/tensorflow/) we borrow this
|
||||
// file and did some modifications so that we can send gRPC
|
||||
// requests without too much copying of the tensor data.
|
||||
|
||||
#include "bytebuffer_stream.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
GrpcByteBufferSource::GrpcByteBufferSource() {}
|
||||
|
||||
bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) {
|
||||
cur_ = -1;
|
||||
left_ = 0;
|
||||
ptr_ = nullptr;
|
||||
byte_count_ = 0;
|
||||
bool ok = src.Dump(&slices_).ok();
|
||||
if (!ok) {
|
||||
slices_.clear();
|
||||
}
|
||||
return ok;
|
||||
}
|
||||
|
||||
bool GrpcByteBufferSource::Next(const void** data, int* size) {
|
||||
// Use loop instead of if in case buffer contained empty slices.
|
||||
while (left_ == 0) {
|
||||
// Advance to next slice.
|
||||
cur_++;
|
||||
if (cur_ >= slices_.size()) {
|
||||
return false;
|
||||
}
|
||||
const ::grpc::Slice& s = slices_[cur_];
|
||||
left_ = s.size();
|
||||
ptr_ = reinterpret_cast<const char*>(s.begin());
|
||||
}
|
||||
|
||||
*data = ptr_;
|
||||
*size = left_;
|
||||
byte_count_ += left_;
|
||||
ptr_ += left_;
|
||||
left_ = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
void GrpcByteBufferSource::BackUp(int count) {
|
||||
ptr_ -= count;
|
||||
left_ += count;
|
||||
byte_count_ -= count;
|
||||
}
|
||||
|
||||
bool GrpcByteBufferSource::Skip(int count) {
|
||||
const void* data;
|
||||
int size;
|
||||
while (Next(&data, &size)) {
|
||||
if (size >= count) {
|
||||
BackUp(size - count);
|
||||
return true;
|
||||
}
|
||||
// size < count;
|
||||
count -= size;
|
||||
}
|
||||
// error or we have too large count;
|
||||
return false;
|
||||
}
|
||||
|
||||
google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
|
||||
return byte_count_;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
// NOTE: This file was originally created by tensorflow
|
||||
// (https://github.com/tensorflow/tensorflow/) we borrow this
|
||||
// file and did some modifications so that we can send gRPC
|
||||
// requests without too much copying of the tensor data.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
#include "google/protobuf/io/coded_stream.h"
|
||||
#include "google/protobuf/io/zero_copy_stream.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
|
||||
class GrpcByteBufferSource
|
||||
: public ::google::protobuf::io::ZeroCopyInputStream {
|
||||
public:
|
||||
GrpcByteBufferSource();
|
||||
bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times.
|
||||
bool Next(const void** data, int* size) override;
|
||||
void BackUp(int count) override;
|
||||
bool Skip(int count) override;
|
||||
::google::protobuf::int64 ByteCount() const override;
|
||||
|
||||
private:
|
||||
std::vector<::grpc::Slice> slices_;
|
||||
size_t cur_; // Current slice index.
|
||||
int left_; // Number of bytes in slices_[cur_] left to yield.
|
||||
const char* ptr_; // Address of next byte in slices_[cur_] to yield.
|
||||
::google::protobuf::int64 byte_count_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,147 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
// NOTE: This file was originally created by tensorflow
|
||||
// (https://github.com/tensorflow/tensorflow/) we borrow this
|
||||
// file and did some modifications so that we can send gRPC
|
||||
// requests without too much copying of the tensor data.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <grpc++/grpc++.h>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
char* EncodeVarint32(char* dst, uint32_t v) {
|
||||
// Operate on characters as unsigneds
|
||||
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
|
||||
static const int B = 128;
|
||||
if (v < (1 << 7)) {
|
||||
*(ptr++) = v;
|
||||
} else if (v < (1 << 14)) {
|
||||
*(ptr++) = v | B;
|
||||
*(ptr++) = v >> 7;
|
||||
} else if (v < (1 << 21)) {
|
||||
*(ptr++) = v | B;
|
||||
*(ptr++) = (v >> 7) | B;
|
||||
*(ptr++) = v >> 14;
|
||||
} else if (v < (1 << 28)) {
|
||||
*(ptr++) = v | B;
|
||||
*(ptr++) = (v >> 7) | B;
|
||||
*(ptr++) = (v >> 14) | B;
|
||||
*(ptr++) = v >> 21;
|
||||
} else {
|
||||
*(ptr++) = v | B;
|
||||
*(ptr++) = (v >> 7) | B;
|
||||
*(ptr++) = (v >> 14) | B;
|
||||
*(ptr++) = (v >> 21) | B;
|
||||
*(ptr++) = v >> 28;
|
||||
}
|
||||
return reinterpret_cast<char*>(ptr);
|
||||
}
|
||||
|
||||
char* EncodeVarint64(char* dst, uint64_t v) {
|
||||
static const int B = 128;
|
||||
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
|
||||
while (v >= B) {
|
||||
*(ptr++) = (v & (B - 1)) | B;
|
||||
v >>= 7;
|
||||
}
|
||||
*(ptr++) = static_cast<unsigned char>(v);
|
||||
return reinterpret_cast<char*>(ptr);
|
||||
}
|
||||
|
||||
int VarintLength(uint64_t v) {
|
||||
int len = 1;
|
||||
while (v >= 128) {
|
||||
v >>= 7;
|
||||
len++;
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
class ProtoEncodeHelper {
|
||||
public:
|
||||
ProtoEncodeHelper(char* buf, int max_size)
|
||||
: base_(buf), p_(buf), limit_(base_ + max_size) {}
|
||||
|
||||
~ProtoEncodeHelper() {
|
||||
// Make sure callers didn't do operations that went over max_size promised
|
||||
PADDLE_ENFORCE_LE(p_, limit_);
|
||||
}
|
||||
|
||||
const char* data() const { return base_; }
|
||||
size_t size() const { return p_ - base_; }
|
||||
|
||||
void WriteUint64(int tag, uint64_t v) {
|
||||
Encode32(combine(tag, WIRETYPE_VARINT));
|
||||
Encode64(v);
|
||||
}
|
||||
void WriteBool(int tag, bool v) {
|
||||
Encode32(combine(tag, WIRETYPE_VARINT));
|
||||
EncodeBool(v);
|
||||
}
|
||||
void WriteString(int tag, const std::string& v) {
|
||||
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
|
||||
Encode32(v.size());
|
||||
EncodeBytes(v.data(), v.size());
|
||||
}
|
||||
void WriteVarlengthBeginning(int tag, uint32_t len) {
|
||||
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
|
||||
Encode32(len);
|
||||
}
|
||||
void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), v.size()); }
|
||||
|
||||
private:
|
||||
// Note: this module's behavior must match the protocol buffer wire encoding
|
||||
// format.
|
||||
enum {
|
||||
WIRETYPE_VARINT = 0,
|
||||
WIRETYPE_LENGTH_DELIMITED = 2,
|
||||
};
|
||||
static uint32_t combine(uint32_t tag, uint32_t type) {
|
||||
return ((tag << 3) | type);
|
||||
}
|
||||
inline void Encode32(uint32_t v) {
|
||||
if (v < 128) {
|
||||
// Fast path for single-byte values. Many of the calls will use a
|
||||
// constant value for v, so the comparison will get optimized away
|
||||
// when Encode32 is inlined into the caller.
|
||||
*p_ = v;
|
||||
p_++;
|
||||
} else {
|
||||
p_ = EncodeVarint32(p_, v);
|
||||
}
|
||||
}
|
||||
void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); }
|
||||
void EncodeBool(bool v) {
|
||||
*p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1
|
||||
p_++;
|
||||
}
|
||||
void EncodeBytes(const char* bytes, int N) {
|
||||
memcpy(p_, bytes, N);
|
||||
p_ += N;
|
||||
}
|
||||
|
||||
char* base_;
|
||||
char* p_;
|
||||
char* limit_; // Just for CHECKs
|
||||
};
|
||||
|
||||
} // detail
|
||||
} // operators
|
||||
} // paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue