Add brpc serialization support. (#11430)
parent 37c2e24511
commit 0b1c7d838c
@@ -0,0 +1,84 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_BRPC_RDMA

#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace distributed {

RdmaMemPool& RdmaMemPool::Instance() {
  static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool();
  return *g_rdma_mem_pool;
}

void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
  pthread_rwlock_rdlock(&access_);
  auto it = pool_.find(varname);
  if (it == pool_.end()) {
    pthread_rwlock_unlock(&access_);
    return nullptr;
  }

  auto info = it->second;
  if (info.data_size != size) {
    pthread_rwlock_unlock(&access_);
    PADDLE_ENFORCE(false, "var:%s size:%ld != %ld", varname, size,
                   info.data_size);
    return nullptr;
  }

  pthread_rwlock_unlock(&access_);
  return info.data;
}

void RdmaMemPool::Register(const std::string& varname, void* data,
                           int64_t data_size) {
  void* old = Find(varname, data_size);
  if (old != nullptr) {
    if (data != old) {
      PADDLE_ENFORCE(false, "var:%s data:%p != %p", varname, data, old);
    }
    VLOG(7) << "Find on rdma:" << varname << " data:" << data
            << " data_size:" << data_size;
    return;
  }

  VarInfo info;
  info.data = data;
  info.data_size = data_size;

  pthread_rwlock_wrlock(&access_);
  pool_[varname] = info;
  pthread_rwlock_unlock(&access_);

  if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) {
    LOG(FATAL) << "register " << varname << " data:" << data
               << " data_size:" << data_size << " error";
  }

  VLOG(4) << "register on rdma:" << varname << " data:" << data
          << " data_size:" << data_size;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle

#endif
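For illustration, a minimal caller sketch for the pool above (hypothetical function and variable name, not part of this commit; assumes a PADDLE_WITH_BRPC_RDMA build):

#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"

void SendBufferOnce(void* buf, int64_t len) {
  namespace dist = paddle::operators::distributed;
  // The first call registers `buf` with brpc's RDMA layer; a repeated
  // call with the same varname, pointer, and size is a no-op.
  dist::RdmaMemPool::Instance().Register("w@grad", buf, len);
  // Find returns the registered pointer, or nullptr if the varname
  // has never been registered.
  void* p = dist::RdmaMemPool::Instance().Find("w@grad", len);
  (void)p;
}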
@@ -0,0 +1,56 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA

#include <pthread.h>  // NOLINT
#include <string>
#include <unordered_map>

namespace paddle {
namespace operators {
namespace distributed {

/*
 * This class is used to avoid duplicate registration of memory
 * with brpc::rdma.
 */
class RdmaMemPool {
 public:
  static RdmaMemPool& Instance();
  RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {}

  virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }

  void Register(const std::string& varname, void* data, int64_t size);
  void* Find(const std::string& varname, int64_t size);

 private:
  struct VarInfo {
    void* data;
    int64_t data_size;

    VarInfo() : data(nullptr), data_size(0) {}
  };

 private:
  std::unordered_map<std::string, VarInfo> pool_;
  pthread_rwlock_t access_;
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle

#endif
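A design note on the locking above: Find and Register unlock manually on every early-return path. A RAII guard, sketched below, is the usual alternative (illustrative only, not part of this commit):

#include <pthread.h>

// Illustrative RAII read-lock guard; unlocks on every exit path.
class ScopedRdlock {
 public:
  explicit ScopedRdlock(pthread_rwlock_t* l) : l_(l) {
    pthread_rwlock_rdlock(l_);
  }
  ~ScopedRdlock() { pthread_rwlock_unlock(l_); }

 private:
  pthread_rwlock_t* l_;
};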
@@ -0,0 +1,196 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread>  // NOLINT

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
namespace distributed {

class IOBufWriter {
 public:
  static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
    iobuf->append(reinterpret_cast<char*>(&k), 4);
    iobuf->append(reinterpret_cast<char*>(&vlen), 8);
    iobuf->append(v, vlen);
  }

  static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
                                int64_t vlen, bool in_cuda_pinned,
                                void (*destroy)(void*), void* user_data) {
    VLOG(7) << "AppendTCPZeroCopy "
            << " k:" << k
            << " data:" << static_cast<void*>(const_cast<char*>(v))
            << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;

    iobuf->append(reinterpret_cast<char*>(&k), 4);
    iobuf->append(reinterpret_cast<char*>(&vlen), 8);

    // FIXME(gongwb): use append_zerocopy
    /*
    if (in_cuda_pinned) {
      iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
    } else {
      iobuf->append_zerocopy(v, vlen, nullptr);
    }
    */
    iobuf->append(v, vlen);
    destroy(user_data);
  }

#ifdef PADDLE_WITH_BRPC_RDMA
  static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf,
                                 int k, const char* v, int64_t vlen,
                                 bool in_cuda_pinned, void (*destroy)(void*),
                                 void* user_data) {
    VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
            << " data:" << static_cast<void*>(const_cast<char*>(v))
            << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;

    iobuf->append(reinterpret_cast<char*>(&k), 4);
    iobuf->append(reinterpret_cast<char*>(&vlen), 8);

    RdmaMemPool::Instance().Register(
        varname, static_cast<void*>(const_cast<char*>(v)), vlen);

    // FIXME(gongwb): use append_zerocopy
    // iobuf->append_zerocopy(v, vlen, nullptr);
    iobuf->append(v, vlen);
    destroy(user_data);
    return;
  }
#endif

  static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf,
                             int k, const char* v, int64_t vlen,
                             bool in_cuda_pinned, void (*destroy)(void*),
                             void* user_data) {
#ifdef PADDLE_WITH_BRPC_RDMA
    IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
                                    destroy, user_data);
#else
    IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
                                   user_data);
#endif
  }
};

void SerializeToIOBuf(const std::string& name, framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      butil::IOBuf* iobuf, const std::string& out_varname,
                      bool var_is_not_stable, int trainer_id,
                      const std::string& table_name) {
  std::unique_ptr<TensorPayload> payload;

  request->set_varname(name);
  request->set_trainer_id(trainer_id);
  // Note: normally the profiler is enabled in 1 trainer, hence only
  // 1 trainer returns true for ShouldSendProfileState(). It tells PS
  // servers the trainer's profiling state so that PS can follow the
  // trainer.
  if (platform::ShouldSendProfileState()) {
    if (platform::IsProfileEnabled()) {
      request->set_profile(platform::kEnableProfiler);
    } else {
      request->set_profile(platform::kDisableProfiler);
    }
  }
  if (!out_varname.empty()) {
    request->set_out_varname(out_varname);
  }
  if (!table_name.empty()) {
    request->set_table_name(table_name);
  }
  if (var->IsType<framework::LoDTensor>()) {
    request->set_type(::sendrecv::LOD_TENSOR);
    payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
  } else if (var->IsType<framework::SelectedRows>()) {
    request->set_type(::sendrecv::SELECTED_ROWS);
    payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
#ifdef PADDLE_WITH_CUDA
  } else if (var->IsType<ncclUniqueId>()) {
    request->set_type(::sendrecv::NCCL_ID);
    const ncclUniqueId& uid = var->Get<ncclUniqueId>();
    // TODO(gongwb): use append_zerocopy to avoid a data copy.
    IOBufWriter::Append(iobuf,
                        sendrecv::VariableMessage::kSerializedFieldNumber,
                        uid.internal, NCCL_UNIQUE_ID_BYTES);
    return;
#endif
  } else {
    PADDLE_THROW("Serialize does not support type: %s",
                 typeid(var->Type()).name());
  }

  PADDLE_ENFORCE_NOT_NULL(payload);

  // FIXME(gongwb): it seems that we can use zero copy here as well.
  if (var_is_not_stable) {
    IOBufWriter::Append(
        iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
        static_cast<const char*>(payload->ptr()), payload->memory_size());
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
      IOBufWriter::AppendZeroCopy(
          name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
          static_cast<const char*>(payload->ptr()), payload->memory_size(),
          true, SerializeDestroyCallback, static_cast<void*>(payload.get()));
      payload.release();
#endif
    } else {
      IOBufWriter::AppendZeroCopy(
          name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
          static_cast<const char*>(payload->ptr()), payload->memory_size(),
          false, SerializeDestroyCallback, static_cast<void*>(payload.get()));
      payload.release();
    }
  }

  if (var->IsType<framework::SelectedRows>()) {
    auto* slr = var->GetMutable<framework::SelectedRows>();
    size_t rows_memory_size =
        slr->rows().size() * framework::SizeOfType(typeid(int64_t));

    IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
                        reinterpret_cast<const char*>(slr->rows().data()),
                        static_cast<int64_t>(rows_memory_size));
  }
}

void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
                          const butil::IOBuf& iobuf,
                          const platform::DeviceContext& ctx,
                          const framework::Scope* scope,
                          framework::Variable** var, int* trainer_id) {
  operators::distributed::BRPCVariableResponse resp(scope, &ctx);
  PADDLE_ENFORCE(resp.Parse(iobuf, meta) == 0, "parse iobuf to tensor error!");
  *var = resp.GetVar();
  *trainer_id = resp.GetTrainerId();
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
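The IOBufWriter above frames every field with a fixed-width header: a 4-byte field number followed by an 8-byte payload length, written in host byte order (little-endian on the supported platforms, matching the ReadLittleEndian32/64 calls in BRPCVariableResponse::Parse), then the raw payload bytes. A standalone sketch of that framing (hypothetical helper, not part of the commit):

#include <cstdint>
#include <string>

// Hypothetical frame encoder mirroring IOBufWriter::Append:
// [4-byte field number][8-byte length][payload bytes].
std::string EncodeFrame(int field, const char* v, int64_t vlen) {
  std::string out;
  out.append(reinterpret_cast<char*>(&field), 4);
  out.append(reinterpret_cast<char*>(&vlen), 8);
  out.append(v, static_cast<size_t>(vlen));
  return out;
}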
@@ -0,0 +1,49 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>

#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"

namespace paddle {
namespace operators {
namespace distributed {

void SerializeToIOBuf(const std::string& name, framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      butil::IOBuf* iobuf, const std::string& out_varname,
                      bool var_is_not_stable, const int trainer_id = 0,
                      const std::string& table_name = std::string());

void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf,
                          const platform::DeviceContext& ctx,
                          const framework::Scope* scope,
                          framework::Variable** var, int* trainer_id);

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
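In brief, the intended call pattern for this pair of functions is a serialize/deserialize round trip; a hedged sketch follows (the serde test below exercises the same flow in full):

#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"

// Hypothetical round-trip helper; `var`, `ctx`, and `scope` are
// prepared by the caller as in the test that follows.
void RoundTrip(paddle::framework::Variable* var,
               const paddle::platform::DeviceContext& ctx,
               const paddle::framework::Scope* scope) {
  butil::IOBuf iobuf;
  sendrecv::VariableMessage msg;
  paddle::operators::distributed::SerializeToIOBuf(
      "myvar", var, ctx, &msg, &iobuf, /*out_varname=*/"",
      /*var_is_not_stable=*/false);

  paddle::framework::Variable* out = nullptr;
  int trainer_id = 0;
  paddle::operators::distributed::DeserializeFromIOBuf(
      msg, iobuf, ctx, scope, &out, &trainer_id);
}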
@@ -0,0 +1,175 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <unistd.h>
#include <string>
#include <thread>  // NOLINT

#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;

void RunSerdeTestSelectedRows(platform::Place place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  butil::IOBuf iobuf;
  sendrecv::VariableMessage msg;
  int tensor_numel = 564 * 128;

  // serialize var to IOBuf
  {
    framework::Variable var;
    auto* slr = var.GetMutable<framework::SelectedRows>();
    slr->set_height(1000);
    auto* tensor = slr->mutable_value();
    auto* rows = slr->mutable_rows();
    tensor->Resize(framework::make_ddim({564, 128}));
    tensor->mutable_data<float>(place);
    math::set_constant(ctx, tensor, 32.7);
    for (int i = 0; i < 564; ++i) rows->push_back(i);

    operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
                                             "", false);
  }

  // deserialize
  {
    framework::Scope scope;
    scope.Var("myvar");
    operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
    EXPECT_EQ(resp.Parse(iobuf, msg), 0);

    framework::Variable* var2 = resp.GetVar();

    auto* slr2 = var2->GetMutable<framework::SelectedRows>();
    auto* tensor2 = slr2->mutable_value();
    auto* rows2 = slr2->mutable_rows();
    float* tensor_data2 = nullptr;
    framework::Tensor tmp_tensor;

    if (platform::is_gpu_place(ctx.GetPlace())) {
      platform::CPUPlace cpu;
      framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
      tensor_data2 = tmp_tensor.data<float>();
    } else {
      tensor_data2 = const_cast<float*>(tensor2->data<float>());
    }
    const int64_t* rows_data2 = rows2->data();

    for (int i = 0; i < tensor_numel; ++i) {
      EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
    }
    for (size_t i = 0; i < rows2->size(); ++i) {
      EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
    }
    EXPECT_EQ(slr2->height(), 1000);
  }
}

void RunTestLodTensor(platform::Place place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  // serialize var to IOBuf
  butil::IOBuf iobuf;
  sendrecv::VariableMessage msg;
  int tensor_numel = 512 * 8 * 4 * 2;
  {
    framework::Variable var;
    auto* tensor = var.GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
    framework::LoD lod;
    lod.push_back(framework::Vector<size_t>({1, 3, 8}));
    tensor->set_lod(lod);
    tensor->mutable_data<float>(place);
    math::set_constant(ctx, tensor, 31.9);

    operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
                                             "", false);
  }

  // check sendrecv::VariableMessage meta data
  {
    EXPECT_EQ(msg.varname(), "myvar");
    EXPECT_EQ(msg.type(), 0);
    EXPECT_EQ(msg.dims()[0], 512);
    EXPECT_EQ(msg.dims()[1], 8);
    EXPECT_EQ(msg.dims()[2], 4);
    EXPECT_EQ(msg.dims()[3], 2);
    EXPECT_EQ(msg.lod_level(), 1);
    EXPECT_EQ(msg.lod(0).lod_data(0), 1);
    EXPECT_EQ(msg.lod(0).lod_data(1), 3);
    EXPECT_EQ(msg.lod(0).lod_data(2), 8);
  }

  // deserialize
  {
    framework::Scope scope;
    scope.Var("myvar");
    operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
    EXPECT_EQ(resp.Parse(iobuf, msg), 0);

    framework::Variable* var2 = resp.GetVar();

    auto tensor2 = var2->Get<framework::LoDTensor>();
    float* tensor_data2 = nullptr;
    framework::Tensor tmp_tensor;

    if (platform::is_gpu_place(ctx.GetPlace())) {
      platform::CPUPlace cpu;
      framework::TensorCopy(tensor2, cpu, &tmp_tensor);
      tensor_data2 = tmp_tensor.data<float>();
    } else {
      tensor_data2 = const_cast<float*>(tensor2.data<float>());
    }

    for (int i = 0; i < tensor_numel; ++i)
      EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
  }
}

TEST(LodTensor, Run) {
  platform::CPUPlace place;
  RunTestLodTensor(place);
#ifdef PADDLE_WITH_CUDA
  platform::CUDAPlace gpu(0);
  RunTestLodTensor(gpu);
#endif
}

TEST(SelectedRows, Run) {
  platform::CPUPlace place;
  RunSerdeTestSelectedRows(place);
#ifdef PADDLE_WITH_CUDA
  platform::CUDAPlace gpu;
  RunSerdeTestSelectedRows(gpu);
#endif
}
@@ -0,0 +1,73 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"

namespace paddle {
namespace operators {
namespace distributed {

namespace pb = ::google::protobuf;
using vr = ::sendrecv::VariableMessage;

int BRPCVariableResponse::Parse(Source* source) {
  pb::io::ZeroCopyInputStream* input_stream = source->contents();
  pb::io::CodedInputStream input(input_stream);
  input.SetTotalBytesLimit(INT_MAX, INT_MAX);

  while (1) {
    unsigned int tag = 0;
    if (!input.ReadLittleEndian32(&tag)) {
      break;
    }

    uint64_t num_bytes = 0;
    if (!input.ReadLittleEndian64(&num_bytes)) {
      break;
    }

    int field = static_cast<int>(tag);
    int ret = field == 0 ? -1 : field;
    switch (field) {
      case vr::kSerializedFieldNumber: {
        if (!ProcSerializedField(field, &input, num_bytes)) {
          return ret;
        }
        break;
      }
      case vr::kRowsFieldNumber: {
        PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                        meta_.type() == sendrecv::LOD_TENSOR) &&
                           meta_.varname() != "",
                       "meta info must be obtained first!");

        if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
          return ret;
        }
        break;
      }
      default: {
        PADDLE_ENFORCE(false, "unsupported field number %u", field);
        return ret;
      }
    }
  }

  return 0;
}
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
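The loop above reads back the same fixed-width frame header that IOBufWriter writes. A minimal sketch of that read in isolation (assumes only protobuf's CodedInputStream; not part of the commit):

#include <cstdint>
#include "google/protobuf/io/coded_stream.h"

// Reads one frame header: a 4-byte field number and an 8-byte payload
// length, both fixed-width little-endian (not protobuf varints).
bool ReadFrameHeader(google::protobuf::io::CodedInputStream* input,
                     uint32_t* field, uint64_t* num_bytes) {
  return input->ReadLittleEndian32(field) &&
         input->ReadLittleEndian64(num_bytes);
}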
@@ -0,0 +1,67 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"

#include "paddle/fluid/operators/distributed/send_recv.pb.h"

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"

namespace paddle {
namespace operators {
namespace distributed {

class BRPCSourceWrapper : public Source {
 public:
  explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {}
  ::google::protobuf::io::ZeroCopyInputStream* contents() override {
    return &source_;
  }

 private:
  butil::IOBufAsZeroCopyInputStream source_;
};

class BRPCVariableResponse : public VariableResponse {
 public:
  BRPCVariableResponse(const framework::Scope* scope,
                       const platform::DeviceContext* dev_ctx,
                       bool create_scope = false)
      : VariableResponse(scope, dev_ctx, create_scope) {}

  virtual ~BRPCVariableResponse() {}

  // parse attachment from iobuf
  int Parse(Source* source) override;
  int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
    BRPCSourceWrapper wrapper(iobuf);
    return VariableResponse::Parse(&wrapper, meta);
  }
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle