add_depthwiseConv_op_gpu
commit
dc488c17d1
@ -0,0 +1,82 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/data_layout_transform.h"
|
||||||
|
|
||||||
|
#include "paddle/framework/tensor.h"
|
||||||
|
#include "paddle/operators/math/math_function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
struct CastDataLayout {
|
||||||
|
CastDataLayout(const platform::DeviceContext* ctx,
|
||||||
|
const std::vector<int>& axis, const framework::Tensor& in,
|
||||||
|
framework::Tensor* out)
|
||||||
|
: in_(in), out_(out), ctx_(ctx), axis_(axis) {}
|
||||||
|
const framework::Tensor in_;
|
||||||
|
framework::Tensor* out_;
|
||||||
|
const platform::DeviceContext* ctx_;
|
||||||
|
const std::vector<int> axis_;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator()() {
|
||||||
|
auto place = ctx_->GetPlace();
|
||||||
|
|
||||||
|
if (platform::is_cpu_place(place)) {
|
||||||
|
operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
|
||||||
|
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
|
||||||
|
trans4(*context, in_, out_, axis_);
|
||||||
|
} else {
|
||||||
|
PADDLE_THROW("Unsupport CPU <-> GPU!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TransDataLayout(const std::vector<int>& axis,
|
||||||
|
const platform::DeviceContext* ctx,
|
||||||
|
const KernelTypePair& kernel_pair, const Variable& in,
|
||||||
|
Variable* out) {
|
||||||
|
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
|
||||||
|
PADDLE_ENFORCE(
|
||||||
|
platform::places_are_same_class(kernel_pair.first.place_,
|
||||||
|
kernel_pair.second.place_),
|
||||||
|
"TransDataLayout only support DataLayout transform on same place!");
|
||||||
|
PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
|
||||||
|
"TransDataLayout only support Datatype are same!");
|
||||||
|
|
||||||
|
auto src = in.Get<Tensor>();
|
||||||
|
auto* dst = out->GetMutable<Tensor>();
|
||||||
|
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
|
||||||
|
|
||||||
|
auto src_dim = src.dims();
|
||||||
|
std::vector<int64_t> dst_dim;
|
||||||
|
|
||||||
|
dst_dim.resize(axis.size());
|
||||||
|
for (size_t i = 0; i < axis.size(); i++) {
|
||||||
|
dst_dim[i] = src_dim[axis[i]];
|
||||||
|
}
|
||||||
|
|
||||||
|
dst->Resize(make_ddim(dst_dim));
|
||||||
|
auto place = kernel_pair.second.place_;
|
||||||
|
dst->mutable_data(place, src.type());
|
||||||
|
|
||||||
|
auto src_type = kernel_pair.first.data_type_;
|
||||||
|
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
|
||||||
|
|
||||||
|
dst->set_layout(kernel_pair.second.data_layout_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/framework/op_kernel_type.h"
|
||||||
|
#include "paddle/framework/variable.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
// Pair of kernel types: `first` describes the source, `second` the
// destination of a data transform.
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
|
||||||
|
|
||||||
|
// Transposes the rank-4 Tensor stored in `in` into `out` using `axis` as the
// dimension permutation, then sets the destination kernel's data layout on
// the result. Both kernel types must share the same class of place and the
// same data type.
void TransDataLayout(const std::vector<int>& axis,
|
||||||
|
const platform::DeviceContext* ctx,
|
||||||
|
const KernelTypePair& kernel_pair, const Variable& in,
|
||||||
|
Variable* out);
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -1,168 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
#include <array>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include "paddle/framework/data_transform.h"
|
|
||||||
#include "paddle/platform/device_context.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
using namespace platform;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief cross validation of different kernel type transform
|
|
||||||
* We use a four-bit map to represent the different combinations.
|
|
||||||
* If a field has more than two possible values, only two of them are chosen.
|
|
||||||
* For DataType, only test the FP32(float), FP64(double).
|
|
||||||
* e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain
|
|
||||||
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
|
|
||||||
*/
|
|
||||||
|
|
||||||
std::array<proto::DataType, 2> kDataType = {
|
|
||||||
{proto::DataType::FP32, proto::DataType::FP64}};
|
|
||||||
|
|
||||||
std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
|
|
||||||
|
|
||||||
std::array<DataLayout, 2> kDataLayout = {{
|
|
||||||
DataLayout::kNHWC, DataLayout::kNCHW,
|
|
||||||
}};
|
|
||||||
|
|
||||||
std::array<LibraryType, 2> kLibraryType = {{
|
|
||||||
LibraryType::kPlain, LibraryType::kMKLDNN,
|
|
||||||
}};
|
|
||||||
|
|
||||||
OpKernelType GenFromBit(const std::vector<bool> bits) {
|
|
||||||
return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
|
|
||||||
kLibraryType[bits[3]]);
|
|
||||||
}
|
|
||||||
|
|
||||||
int test_value = 0;
|
|
||||||
|
|
||||||
auto kernel0 = GenFromBit({0, 0, 0, 0});
|
|
||||||
auto kernel1 = GenFromBit({0, 0, 0, 1});
|
|
||||||
auto kernel2 = GenFromBit({0, 0, 1, 0});
|
|
||||||
auto kernel3 = GenFromBit({0, 0, 1, 1});
|
|
||||||
|
|
||||||
// Stub transform: ignores its arguments and adds one to test_value so the
// test can tell which function was dispatched.
void TransDataType_t(const platform::DeviceContext* ctx,
                     const KernelTypePair& p, const Variable& in,
                     Variable* out) {
  test_value += 1;
}
|
|
||||||
|
|
||||||
// Stub transform: ignores its arguments and subtracts one from test_value so
// the test can tell which function was dispatched.
void TransDataLayout_t(const platform::DeviceContext* ctx,
                       const KernelTypePair& p, const Variable& in,
                       Variable* out) {
  test_value -= 1;
}
|
|
||||||
|
|
||||||
// Stub transform: ignores its arguments and adds two to test_value so the
// test can tell which function was dispatched.
void TransLibraryType_t(const platform::DeviceContext* ctx,
                        const KernelTypePair& p, const Variable& in,
                        Variable* out) {
  test_value = test_value + 2;
}
|
|
||||||
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
||||||
|
|
||||||
// Short alias used by the registrations and tests below.
namespace frw = paddle::framework;
|
|
||||||
|
|
||||||
// Associate each (source kernel, destination kernel) pair with a stub
// transform. TEST(DataTransform, Register) looks these pairs up through
// DataTransformFnMap and checks test_value after each call.
// NOTE(review): the macro definition is not visible here; presumably it
// inserts the function into DataTransformFnMap at static-init time.
REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t);
|
|
||||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t);
|
|
||||||
REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t);
|
|
||||||
|
|
||||||
// Checks that each registered kernel-type pair dispatches to the stub that
// was registered for it, by watching the change of test_value:
// +1 (TransDataType_t), -1 (TransDataLayout_t), +2 (TransLibraryType_t).
TEST(DataTransform, Register) {
  using namespace paddle::framework;
  using namespace paddle::platform;

  auto& instance = DataTransformFnMap::Instance();
  paddle::framework::Variable in;
  paddle::framework::Variable out;

  // Keep the context on the stack so it is released when the test returns,
  // including the early return taken by a failing ASSERT_EQ.
  CPUDeviceContext cpu_ctx;
  DeviceContext* ctx = &cpu_ctx;

  auto pair0 = std::make_pair(frw::kernel0, frw::kernel1);
  instance.Get(pair0)(ctx, pair0, in, &out);
  ASSERT_EQ(test_value, 1);

  auto pair1 = std::make_pair(frw::kernel1, frw::kernel2);
  instance.Get(pair1)(ctx, pair1, in, &out);
  ASSERT_EQ(test_value, 0);

  auto pair3 = std::make_pair(frw::kernel0, frw::kernel2);
  instance.Get(pair3)(ctx, pair3, in, &out);
  ASSERT_EQ(test_value, 2);
}
|
|
||||||
|
|
||||||
// Round-trips a {2, 3, 1, 2} double tensor from kNHWC to kNCHW and back,
// checking the layout tag and the dimensions after each transform.
TEST(DataTransform, DataLayout) {
  using namespace paddle::framework;
  using namespace paddle::platform;

  auto& instance = DataTransformFnMap::Instance();
  Variable in;
  Variable out;
  Tensor* src = in.GetMutable<Tensor>();
  src->mutable_data<double>(make_ddim({2, 3, 1, 2}), CPUPlace());
  src->set_layout(DataLayout::kNHWC);

  // Keep the context on the stack so it is released when the test returns.
  CPUDeviceContext cpu_ctx;
  DeviceContext* ctx = &cpu_ctx;

  {
    // FP64 / CPU / kNHWC  ->  FP64 / CPU / kNCHW
    auto kernel1 = GenFromBit({1, 0, 0, 0});
    auto kernel2 = GenFromBit({1, 0, 1, 0});
    auto pair0 = std::make_pair(kernel1, kernel2);
    instance.Get(pair0)(ctx, pair0, in, &out);
  }

  Tensor dst = out.Get<Tensor>();

  EXPECT_TRUE(dst.layout() == DataLayout::kNCHW);
  EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1}));

  {
    // Reverse direction: kNCHW back to kNHWC, writing into `in`.
    auto kernel1 = GenFromBit({1, 0, 1, 0});
    auto kernel2 = GenFromBit({1, 0, 0, 0});
    auto pair0 = std::make_pair(kernel1, kernel2);
    instance.Get(pair0)(ctx, pair0, out, &in);
  }

  EXPECT_TRUE(src->layout() == DataLayout::kNHWC);
  EXPECT_TRUE(src->dims() == make_ddim({2, 3, 1, 2}));
}
|
|
||||||
|
|
||||||
// Converts a {2, 3} float tensor to the FP64 kernel type and checks that the
// result can be read back as double.
TEST(DataTransform, DataType) {
  using namespace paddle::framework;
  using namespace paddle::platform;

  auto& instance = DataTransformFnMap::Instance();

  // Keep the context on the stack so it is released when the test returns.
  CPUDeviceContext cpu_ctx;
  DeviceContext* ctx = &cpu_ctx;

  Variable in;
  Variable out;
  Tensor* src = in.GetMutable<Tensor>();
  float* ptr = src->mutable_data<float>(make_ddim({2, 3}), CPUPlace());
  for (int i = 0; i < 6; ++i) {
    // Integer division: the first three elements are 0, the last three are 1.
    ptr[i] = i / 3;
  }

  {
    // FP32 / CPU / kNHWC / kPlain  ->  FP64 / CPU / kNHWC / kPlain
    auto kernel1 = GenFromBit({0, 0, 0, 0});
    auto kernel2 = GenFromBit({1, 0, 0, 0});
    auto pair0 = std::make_pair(kernel1, kernel2);
    instance.Get(pair0)(ctx, pair0, in, &out);
  }
  Tensor dst = out.Get<Tensor>();
  EXPECT_TRUE(dst.data<double>() != nullptr);
}
|
|
@ -0,0 +1,99 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/data_type_transform.h"
|
||||||
|
|
||||||
|
#include "paddle/framework/selected_rows.h"
|
||||||
|
#include "paddle/platform/transform.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
template <typename InType, typename OutType>
|
||||||
|
struct CastDataTypeFunctor {
|
||||||
|
HOSTDEVICE inline OutType operator()(InType in) const {
|
||||||
|
return static_cast<OutType>(in);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType>
|
||||||
|
struct CastDataType {
|
||||||
|
CastDataType(const framework::Tensor& in, framework::Tensor* out,
|
||||||
|
const platform::DeviceContext* ctx)
|
||||||
|
: in_(in), out_(out), ctx_(ctx) {}
|
||||||
|
const framework::Tensor in_;
|
||||||
|
framework::Tensor* out_;
|
||||||
|
const platform::DeviceContext* ctx_;
|
||||||
|
|
||||||
|
template <typename OutType>
|
||||||
|
void operator()() {
|
||||||
|
auto place = ctx_->GetPlace();
|
||||||
|
|
||||||
|
auto* in_begin = in_.data<InType>();
|
||||||
|
auto numel = in_.numel();
|
||||||
|
auto* in_end = in_begin + numel;
|
||||||
|
auto* out_begin = out_->mutable_data<OutType>(place);
|
||||||
|
|
||||||
|
if (platform::is_cpu_place(place)) {
|
||||||
|
platform::Transform<platform::CPUDeviceContext> trans;
|
||||||
|
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
|
||||||
|
trans(*context, in_begin, in_end, out_begin,
|
||||||
|
CastDataTypeFunctor<InType, OutType>());
|
||||||
|
} else {
|
||||||
|
// TODO(dzhwinter): enhance Copy CPU<->GPU with different data type?
|
||||||
|
PADDLE_THROW("Unsupport CPU <-> GPU!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Converts the Tensor held by `in` to the data type of the destination kernel
// and stores it in `out` with the same dimensions.
//
// ctx:         device context used for the conversion (CPU only, see
//              CastDataType).
// kernel_pair: (source kernel type, destination kernel type); both must be on
//              the same class of place.
// in:          variable holding the source Tensor.
// out:         variable receiving the converted Tensor.
//
// Supported source types are FP32, FP64, INT32, INT64 and BOOL; any other
// source type throws.
void TransDataType(const platform::DeviceContext* ctx,
                   const KernelTypePair& kernel_pair, const Variable& in,
                   Variable* out) {
  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
  PADDLE_ENFORCE(
      platform::places_are_same_class(kernel_pair.first.place_,
                                      kernel_pair.second.place_),
      "TransDataType Only Support DataType transform on same place!");

  Tensor src = in.Get<Tensor>();
  Tensor* dst = out->GetMutable<Tensor>();

  // The conversion is element-wise, so the shape carries over unchanged.
  dst->Resize(src.dims());

  const auto dst_type = kernel_pair.second.data_type_;
  const auto src_type = kernel_pair.first.data_type_;

  // Fix the source element type here; VisitDataType resolves the destination.
  if (src_type == proto::DataType::FP32) {
    framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx));
  } else if (src_type == proto::DataType::FP64) {
    framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx));
  } else if (src_type == proto::DataType::INT32) {
    framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
  } else if (src_type == proto::DataType::INT64) {
    framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
  } else if (src_type == proto::DataType::BOOL) {
    framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx));
  } else {
    PADDLE_THROW("Not support type %d", src_type);
  }
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/framework/op_kernel_type.h"
|
||||||
|
#include "paddle/framework/variable.h"
|
||||||
|
#include "paddle/platform/device_context.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
// Pair of kernel types: `first` describes the source, `second` the
// destination of a data transform.
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
|
||||||
|
|
||||||
|
// Converts the Tensor stored in `in` to the destination kernel's data type
// and writes it to `out` with unchanged dimensions. Both kernel types must be
// on the same class of place.
void TransDataType(const platform::DeviceContext* ctx,
|
||||||
|
const KernelTypePair& kernel_pair, const Variable& in,
|
||||||
|
Variable* out);
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -1 +1 @@
|
|||||||
grpc_library(sendrecvop_grpc SRCS recv_impl.cc send_impl.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
|
||||||
|
@ -0,0 +1,147 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "grpc_client.h"
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
bool RPCClient::AsyncSendVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out) {
|
||||||
|
sendrecv::VariableMessage req;
|
||||||
|
auto* var = scope.FindVar(var_name);
|
||||||
|
SerializeToMessage(var_name, var, ctx, &req);
|
||||||
|
|
||||||
|
// varhandle
|
||||||
|
VarHandle var_h;
|
||||||
|
var_h.ep = ep;
|
||||||
|
var_h.scope = &scope;
|
||||||
|
var_h.name = var_name;
|
||||||
|
var_h.ctx = &ctx;
|
||||||
|
|
||||||
|
// stub context
|
||||||
|
auto ch = GetChannel(ep);
|
||||||
|
SendProcessor* s = new SendProcessor(ch);
|
||||||
|
s->Prepare(var_h, time_out);
|
||||||
|
s->response_call_back_ = NULL;
|
||||||
|
|
||||||
|
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
|
||||||
|
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||||
|
|
||||||
|
req_count_++;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ProcGetResponse(const VarHandle& var_h,
|
||||||
|
const sendrecv::VariableMessage& ret_msg) {
|
||||||
|
auto* outvar = var_h.scope->FindVar(var_h.name);
|
||||||
|
|
||||||
|
std::istringstream iss(ret_msg.serialized());
|
||||||
|
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RPCClient::AsyncGetVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out) {
|
||||||
|
sendrecv::VariableMessage req;
|
||||||
|
req.set_varname(var_name);
|
||||||
|
|
||||||
|
auto* var = scope.FindVar(var_name);
|
||||||
|
SerializeToMessage(var_name, var, ctx, &req);
|
||||||
|
|
||||||
|
// varhandle
|
||||||
|
VarHandle var_h;
|
||||||
|
var_h.ep = ep;
|
||||||
|
var_h.scope = &scope;
|
||||||
|
var_h.name = var_name;
|
||||||
|
var_h.ctx = &ctx;
|
||||||
|
|
||||||
|
// stub context
|
||||||
|
auto ch = GetChannel(ep);
|
||||||
|
GetProcessor* s = new GetProcessor(ch);
|
||||||
|
s->Prepare(var_h, time_out);
|
||||||
|
s->response_call_back_ = ProcGetResponse;
|
||||||
|
|
||||||
|
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
|
||||||
|
rpc->Finish(&s->reply_, &s->status_, (void*)s);
|
||||||
|
|
||||||
|
req_count_++;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RPCClient::wait() {
|
||||||
|
bool ok = true;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (req_count_ <= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Proceed()) {
|
||||||
|
LOG(ERROR) << "Get meets CompletionQueue error";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RPCClient::Proceed() {
|
||||||
|
void* tag = NULL;
|
||||||
|
bool ok = false;
|
||||||
|
|
||||||
|
// request counts.
|
||||||
|
if (!cq_.Next(&tag, &ok)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
req_count_--;
|
||||||
|
|
||||||
|
GPR_ASSERT(ok);
|
||||||
|
PADDLE_ENFORCE(tag);
|
||||||
|
|
||||||
|
// TODO(gongwb): add more retries.
|
||||||
|
ClientBase* c = static_cast<ClientBase*>(tag);
|
||||||
|
if (!c->status_.ok()) {
|
||||||
|
delete c;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
c->Process();
|
||||||
|
delete c;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
|
||||||
|
auto it = channels_.find(ep);
|
||||||
|
if (it != channels_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ch = std::shared_ptr<grpc::Channel>(
|
||||||
|
grpc::CreateChannel(ep, grpc::InsecureChannelCredentials()));
|
||||||
|
|
||||||
|
channels_[ep] = ch;
|
||||||
|
return ch;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,147 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <grpc++/grpc++.h>
|
||||||
|
#include <grpc/support/log.h>
|
||||||
|
#include <time.h>
|
||||||
|
#include <chrono>
|
||||||
|
#include <ctime>
|
||||||
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "paddle/framework/data_type.h"
|
||||||
|
#include "paddle/framework/lod_tensor.h"
|
||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "paddle/framework/selected_rows.h"
|
||||||
|
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||||
|
#include "paddle/operators/detail/simple_block_queue.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
struct VarHandle {
|
||||||
|
std::string ep;
|
||||||
|
const platform::DeviceContext* ctx;
|
||||||
|
const framework::Scope* scope;
|
||||||
|
std::string name;
|
||||||
|
|
||||||
|
std::string String() const {
|
||||||
|
std::ostringstream s;
|
||||||
|
s << "name:[" << name << "] ep:[" << ep << "]";
|
||||||
|
return s.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Default response handler for GetVariable calls: deserializes `msg` into the
// variable named var_h.name found in var_h.scope.
void ProcGetResponse(const VarHandle& var_h,
|
||||||
|
const sendrecv::VariableMessage& msg);
|
||||||
|
|
||||||
|
class ClientBase {
|
||||||
|
public:
|
||||||
|
explicit ClientBase(std::shared_ptr<grpc::Channel> ch) {
|
||||||
|
stub_ = sendrecv::SendRecvService::NewStub(ch);
|
||||||
|
context_ = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~ClientBase() {}
|
||||||
|
|
||||||
|
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
|
||||||
|
context_.reset(new grpc::ClientContext());
|
||||||
|
var_h_ = var_info;
|
||||||
|
|
||||||
|
std::chrono::system_clock::time_point deadline =
|
||||||
|
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
|
||||||
|
|
||||||
|
context_->set_deadline(deadline);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void Process() = 0;
|
||||||
|
|
||||||
|
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
|
||||||
|
std::unique_ptr<grpc::ClientContext> context_;
|
||||||
|
grpc::Status status_;
|
||||||
|
VarHandle var_h_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::function<void(const VarHandle&, const sendrecv::VoidMessage&)>
|
||||||
|
RequestSendCallBack;
|
||||||
|
|
||||||
|
class SendProcessor : public ClientBase {
|
||||||
|
public:
|
||||||
|
explicit SendProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||||
|
|
||||||
|
virtual ~SendProcessor() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
if (response_call_back_) {
|
||||||
|
response_call_back_(var_h_, reply_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendrecv::VoidMessage reply_;
|
||||||
|
RequestSendCallBack response_call_back_ = NULL;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::function<void(const VarHandle&, const sendrecv::VariableMessage&)>
|
||||||
|
RequestGetCallBack;
|
||||||
|
|
||||||
|
class GetProcessor : public ClientBase {
|
||||||
|
public:
|
||||||
|
explicit GetProcessor(std::shared_ptr<grpc::Channel> ch) : ClientBase(ch) {}
|
||||||
|
|
||||||
|
virtual ~GetProcessor() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
if (response_call_back_) {
|
||||||
|
response_call_back_(var_h_, reply_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendrecv::VariableMessage reply_;
|
||||||
|
RequestGetCallBack response_call_back_ = ProcGetResponse;
|
||||||
|
};
|
||||||
|
|
||||||
|
class RPCClient {
|
||||||
|
public:
|
||||||
|
bool AsyncSendVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out = 600 * 1000);
|
||||||
|
|
||||||
|
bool AsyncGetVariable(const std::string& ep,
|
||||||
|
const platform::DeviceContext& ctx,
|
||||||
|
const framework::Scope& scope,
|
||||||
|
const std::string& var_name,
|
||||||
|
int64_t time_out = 600 * 1000);
|
||||||
|
bool wait();
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool Proceed();
|
||||||
|
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
|
||||||
|
|
||||||
|
private:
|
||||||
|
grpc::CompletionQueue cq_;
|
||||||
|
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
|
||||||
|
int64_t req_count_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,237 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/detail/grpc_server.h"
|
||||||
|
|
||||||
|
using grpc::ServerAsyncResponseWriter;
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
enum CallStatus { PROCESS = 0, FINISH };
|
||||||
|
|
||||||
|
// reference:
|
||||||
|
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
|
||||||
|
class RequestBase {
|
||||||
|
public:
|
||||||
|
explicit RequestBase(sendrecv::SendRecvService::AsyncService* service,
|
||||||
|
grpc::ServerCompletionQueue* cq)
|
||||||
|
: service_(service), cq_(cq), status_(PROCESS) {}
|
||||||
|
virtual ~RequestBase() {}
|
||||||
|
virtual void Process() { assert(false); }
|
||||||
|
|
||||||
|
CallStatus Status() { return status_; }
|
||||||
|
void SetStatus(CallStatus status) { status_ = status; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
grpc::ServerContext ctx_;
|
||||||
|
sendrecv::SendRecvService::AsyncService* service_;
|
||||||
|
grpc::ServerCompletionQueue* cq_;
|
||||||
|
CallStatus status_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||||
|
|
||||||
|
class RequestSend final : public RequestBase {
|
||||||
|
public:
|
||||||
|
explicit RequestSend(sendrecv::SendRecvService::AsyncService* service,
|
||||||
|
grpc::ServerCompletionQueue* cq,
|
||||||
|
SimpleBlockQueue<MessageWithName>* queue)
|
||||||
|
: RequestBase(service, cq), queue_(queue), responder_(&ctx_) {
|
||||||
|
service_->RequestSendVariable(&ctx_, &request_, &responder_, cq_, cq_,
|
||||||
|
this);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~RequestSend() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
MessageWithName msg_with_name =
|
||||||
|
std::make_pair(request_.varname(), std::move(request_));
|
||||||
|
queue_->Push(std::move(msg_with_name));
|
||||||
|
// TODO(gongwb): check var's info.
|
||||||
|
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
sendrecv::VariableMessage request_;
|
||||||
|
sendrecv::VoidMessage reply_;
|
||||||
|
SimpleBlockQueue<MessageWithName>* queue_;
|
||||||
|
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class RequestGet final : public RequestBase {
|
||||||
|
public:
|
||||||
|
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
|
||||||
|
grpc::ServerCompletionQueue* cq, framework::Scope* scope)
|
||||||
|
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) {
|
||||||
|
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~RequestGet() {}
|
||||||
|
|
||||||
|
virtual void Process() {
|
||||||
|
// proc request.
|
||||||
|
std::string var_name = request_.varname();
|
||||||
|
auto* var = scope_->FindVar(var_name);
|
||||||
|
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
|
||||||
|
// TODO(gongwb): check var's info.
|
||||||
|
responder_.Finish(reply_, grpc::Status::OK, this);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
sendrecv::VariableMessage request_;
|
||||||
|
sendrecv::VariableMessage reply_;
|
||||||
|
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
|
||||||
|
framework::Scope* scope_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void AsyncGRPCServer::RunSyncUpdate() {
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
builder.AddListeningPort(address_, grpc::InsecureServerCredentials());
|
||||||
|
builder.RegisterService(&service_);
|
||||||
|
|
||||||
|
cq_send_ = builder.AddCompletionQueue();
|
||||||
|
cq_get_ = builder.AddCompletionQueue();
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
LOG(INFO) << "Server listening on " << address_ << std::endl;
|
||||||
|
|
||||||
|
std::function<void()> send_register =
|
||||||
|
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
|
||||||
|
std::function<void()> get_register =
|
||||||
|
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
|
||||||
|
|
||||||
|
t_send_.reset(
|
||||||
|
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
|
||||||
|
cq_send_.get(), "cq_send", send_register)));
|
||||||
|
|
||||||
|
t_get_.reset(
|
||||||
|
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
|
||||||
|
cq_get_.get(), "cq_get", get_register)));
|
||||||
|
|
||||||
|
// wait server
|
||||||
|
server_->Wait();
|
||||||
|
t_send_->join();
|
||||||
|
t_get_->join();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::ShutdownQueue() {
|
||||||
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||||
|
cq_send_->Shutdown();
|
||||||
|
cq_get_->Shutdown();
|
||||||
|
is_shut_down_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This URL explains why shutdown is complicate:
|
||||||
|
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
|
||||||
|
void AsyncGRPCServer::ShutDown() {
|
||||||
|
server_->Shutdown();
|
||||||
|
ShutdownQueue();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::TryToRegisterNewSendOne() {
|
||||||
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||||
|
if (is_shut_down_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
RequestSend* send =
|
||||||
|
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
|
||||||
|
VLOG(4) << "create RequestSend status:" << send->Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::TryToRegisterNewGetOne() {
|
||||||
|
std::unique_lock<std::mutex> lock(cq_mutex_);
|
||||||
|
if (is_shut_down_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_);
|
||||||
|
VLOG(4) << "create Requestget status:" << get->Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called after a request has been processed. While the queues are running the
// request is marked FINISH; once they are shut down the request is deleted
// and the caller's pointer is cleared.
void AsyncGRPCServer::SetFinishOrDelete(RequestBase*& last) {
  std::lock_guard<std::mutex> guard(cq_mutex_);
  if (!is_shut_down_) {
    last->SetStatus(FINISH);
    return;
  }

  delete last;
  last = nullptr;
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
|
||||||
|
std::string cq_name,
|
||||||
|
std::function<void()> TryToRegisterNewOne) {
|
||||||
|
TryToRegisterNewOne();
|
||||||
|
|
||||||
|
void* tag = NULL;
|
||||||
|
bool ok = false;
|
||||||
|
while (true) {
|
||||||
|
if (!cq->Next(&tag, &ok)) {
|
||||||
|
LOG(INFO) << cq_name << " get CompletionQueue shutdown!";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wait && !done_) {
|
||||||
|
Wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
RequestBase* base = (RequestBase*)tag;
|
||||||
|
if (!ok) {
|
||||||
|
VLOG(4) << cq_name << " recv no regular event";
|
||||||
|
TryToRegisterNewOne();
|
||||||
|
delete base;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (base->Status()) {
|
||||||
|
case PROCESS: {
|
||||||
|
VLOG(4) << cq_name << " status:" << base->Status();
|
||||||
|
TryToRegisterNewOne();
|
||||||
|
base->Process();
|
||||||
|
SetFinishOrDelete(base);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case FINISH: {
|
||||||
|
VLOG(4) << cq_name << " status:" << base->Status();
|
||||||
|
delete base;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: { assert(false); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::Wait() {
|
||||||
|
std::unique_lock<std::mutex> lock(this->mutex_);
|
||||||
|
condition_.wait(lock, [=] { return this->done_ == true; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::Reset() {
|
||||||
|
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||||
|
done_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsyncGRPCServer::Done() {
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||||
|
done_ = true;
|
||||||
|
}
|
||||||
|
condition_.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,91 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/framework/lod_tensor.h"
|
||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "paddle/framework/selected_rows.h"
|
||||||
|
#include "paddle/framework/var_type.h"
|
||||||
|
#include "paddle/operators/detail/simple_block_queue.h"
|
||||||
|
|
||||||
|
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
||||||
|
#include "paddle/operators/detail/send_recv.pb.h"
|
||||||
|
|
||||||
|
#include <grpc++/grpc++.h>
|
||||||
|
#include <grpc/support/log.h>
|
||||||
|
#include <thread>
|
||||||
|
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
||||||
|
class RequestBase;
|
||||||
|
|
||||||
|
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
|
||||||
|
public:
|
||||||
|
explicit AsyncGRPCServer(std::string address) { address_ = address; }
|
||||||
|
|
||||||
|
void RunSyncUpdate();
|
||||||
|
|
||||||
|
void Reset();
|
||||||
|
|
||||||
|
void Done();
|
||||||
|
|
||||||
|
void SetScope(framework::Scope *scope) { scope_ = scope; }
|
||||||
|
|
||||||
|
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
||||||
|
|
||||||
|
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
||||||
|
|
||||||
|
void ShutDown();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void Wait();
|
||||||
|
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
|
||||||
|
std::string cq_name,
|
||||||
|
std::function<void()> TryToRegisterNewOne);
|
||||||
|
void TryToRegisterNewSendOne();
|
||||||
|
void TryToRegisterNewGetOne();
|
||||||
|
void SetFinishOrDelete(RequestBase *&last);
|
||||||
|
void ShutdownQueue();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex cq_mutex_;
|
||||||
|
volatile bool is_shut_down_ = false;
|
||||||
|
std::unique_ptr<grpc::ServerCompletionQueue> cq_send_;
|
||||||
|
std::unique_ptr<grpc::ServerCompletionQueue> cq_get_;
|
||||||
|
|
||||||
|
sendrecv::SendRecvService::AsyncService service_;
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
|
||||||
|
std::string address_;
|
||||||
|
framework::Scope *scope_;
|
||||||
|
// received variable from RPC, operators fetch variable from this queue.
|
||||||
|
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
||||||
|
|
||||||
|
// condition of the sub program
|
||||||
|
std::mutex mutex_;
|
||||||
|
volatile mutable bool done_;
|
||||||
|
std::condition_variable condition_;
|
||||||
|
|
||||||
|
std::unique_ptr<std::thread> t_send_;
|
||||||
|
std::unique_ptr<std::thread> t_get_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}; // namespace detail
|
||||||
|
}; // namespace operators
|
||||||
|
}; // namespace paddle
|
@ -1,65 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#include "send_recv_impl.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace operators {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
// Stores the received variable message in the receive queue, keyed by its
// variable name, and reports success.
Status SendRecvServerImpl::SendVariable(ServerContext *context,
                                        const VariableMessage *in_var,
                                        VoidMessage *out_var) {
  // in_var points to a const message, so the queue receives a copy of it.
  MessageWithName named_msg(in_var->varname(), *in_var);
  var_recv_queue_.Push(std::move(named_msg));
  return Status::OK;
}
|
|
||||||
|
|
||||||
// Looks up the variable named in `in_var` in scope_ and serializes it into
// `out_var`. Returns NOT_FOUND when the scope has no such variable.
Status SendRecvServerImpl::GetVariable(ServerContext *context,
                                       const VariableMessage *in_var,
                                       VariableMessage *out_var) {
  const std::string get_var_name = in_var->varname();
  auto *var = scope_->FindVar(get_var_name);
  if (var == nullptr) {
    // SerializeToMessage dereferences `var`, so a failed lookup must not
    // reach it.
    return Status(grpc::StatusCode::NOT_FOUND,
                  "variable not found: " + get_var_name);
  }

  SerializeToMessage(get_var_name, var, platform::CPUDeviceContext(), out_var);
  return Status::OK;
}
|
|
||||||
|
|
||||||
// RPC handler that blocks until done_ is true and then returns OK.
Status SendRecvServerImpl::Wait(ServerContext *context,
                                const VoidMessage *in_var,
                                VoidMessage *out_var) {
  std::unique_lock<std::mutex> guard(mutex_);
  // Re-check after every wake-up; the flag is only read while holding mutex_.
  while (!done_) {
    condition_.wait(guard);
  }
  guard.unlock();
  return Status::OK;
}
|
|
||||||
|
|
||||||
void SendRecvServerImpl::Reset() {
|
|
||||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
|
||||||
done_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SendRecvServerImpl::Done() {
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
|
||||||
done_ = true;
|
|
||||||
}
|
|
||||||
condition_.notify_all();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace operators
|
|
||||||
} // namespace paddle
|
|
@ -1,67 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#include "send_recv_impl.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace operators {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
// Serializes the variable `inname` from `scope` and sends it to the server.
// Returns true on success; on an RPC failure the error is logged and false is
// returned. PADDLE_ENFORCE fails if the variable is not in the scope.
bool RPCClient::SendVariable(const framework::Scope& scope,
                             const std::string& inname) {
  ClientContext rpc_context;
  VariableMessage request;
  VoidMessage reply;
  // FIXME(typhoonzero): pass device context to here.
  platform::CPUDeviceContext cpu_ctx;
  auto* source_var = scope.FindVar(inname);
  PADDLE_ENFORCE(source_var);
  SerializeToMessage(inname, source_var, cpu_ctx, &request);

  const Status rpc_status = stub_->SendVariable(&rpc_context, request, &reply);
  if (rpc_status.ok()) {
    return true;
  }
  LOG(ERROR) << "gRPC error: " << rpc_status.error_message();
  return false;
}
|
|
||||||
|
|
||||||
// Requests the variable `outname` from the server and deserializes the reply
// into the variable of the same name in `scope`.
// Returns false, after logging, if the RPC fails or if `scope` has no
// variable called `outname`; returns true otherwise.
bool RPCClient::GetVariable(const framework::Scope& scope,
                            const std::string& outname) {
  ClientContext context;
  VariableMessage call_msg, ret_msg;
  call_msg.set_varname(outname);
  auto ctx = platform::CPUDeviceContext();
  Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
  if (!status.ok()) {
    LOG(ERROR) << "gRPC error: " << status.error_message();
    return false;
  }

  auto* outvar = scope.FindVar(outname);
  if (outvar == nullptr) {
    // DeserializeFromMessage dereferences the target variable, so a failed
    // lookup must not reach it.
    LOG(ERROR) << "variable not found in scope: " << outname;
    return false;
  }

  // DeserializeFromMessage reads ret_msg.serialized() itself; no separate
  // stream over the payload is needed here.
  DeserializeFromMessage(ret_msg, ctx, outvar);

  return true;
}
|
|
||||||
|
|
||||||
void RPCClient::Wait() {
|
|
||||||
ClientContext context;
|
|
||||||
VoidMessage call_msg, ret_msg;
|
|
||||||
stub_->Wait(&context, call_msg, &ret_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace operators
|
|
||||||
} // namespace paddle
|
|
@ -1,141 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "paddle/framework/lod_tensor.h"
|
|
||||||
#include "paddle/framework/scope.h"
|
|
||||||
#include "paddle/framework/selected_rows.h"
|
|
||||||
#include "paddle/framework/var_type.h"
|
|
||||||
#include "paddle/operators/detail/simple_block_queue.h"
|
|
||||||
|
|
||||||
#include "paddle/operators/detail/send_recv.grpc.pb.h"
|
|
||||||
#include "paddle/operators/detail/send_recv.pb.h"
|
|
||||||
|
|
||||||
#include <grpc++/grpc++.h>
|
|
||||||
|
|
||||||
using grpc::Channel;
|
|
||||||
using grpc::Server;
|
|
||||||
using grpc::ServerContext;
|
|
||||||
using grpc::ServerReader;
|
|
||||||
using grpc::ServerBuilder;
|
|
||||||
|
|
||||||
using grpc::ClientContext;
|
|
||||||
using grpc::ClientReader;
|
|
||||||
using grpc::ClientReaderWriter;
|
|
||||||
using grpc::ClientWriter;
|
|
||||||
using grpc::Status;
|
|
||||||
using sendrecv::SendRecvService;
|
|
||||||
using sendrecv::VariableMessage;
|
|
||||||
using sendrecv::VoidMessage;
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace operators {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
|
|
||||||
|
|
||||||
class SendRecvServerImpl final : public SendRecvService::Service {
|
|
||||||
public:
|
|
||||||
explicit SendRecvServerImpl() {}
|
|
||||||
|
|
||||||
Status SendVariable(ServerContext *context, const VariableMessage *in_var,
|
|
||||||
VoidMessage *out_var) override;
|
|
||||||
Status GetVariable(ServerContext *context, const VariableMessage *in_var,
|
|
||||||
VariableMessage *out_var) override;
|
|
||||||
Status Wait(ServerContext *context, const VoidMessage *in_var,
|
|
||||||
VoidMessage *out_var) override;
|
|
||||||
void Reset();
|
|
||||||
void Done();
|
|
||||||
void SetScope(framework::Scope *scope) { scope_ = scope; };
|
|
||||||
|
|
||||||
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
|
|
||||||
|
|
||||||
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
// received variable from RPC, operators fetch variable from this queue.
|
|
||||||
SimpleBlockQueue<MessageWithName> var_recv_queue_;
|
|
||||||
framework::Scope *scope_;
|
|
||||||
// condition of the sub program
|
|
||||||
std::mutex mutex_;
|
|
||||||
bool done_;
|
|
||||||
std::condition_variable condition_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// RPCClient is a class to send tensors to pserver sub-network
|
|
||||||
// using different hashing methods.
|
|
||||||
class RPCClient {
|
|
||||||
public:
|
|
||||||
RPCClient(std::shared_ptr<Channel> channel)
|
|
||||||
: stub_(SendRecvService::NewStub(channel)) {}
|
|
||||||
|
|
||||||
bool SendVariable(const framework::Scope &scope, const std::string &inname);
|
|
||||||
bool GetVariable(const framework::Scope &scope, const std::string &outname);
|
|
||||||
void Wait();
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unique_ptr<SendRecvService::Stub> stub_;
|
|
||||||
};
|
|
||||||
|
|
||||||
inline void SerializeToMessage(const std::string &name,
|
|
||||||
const framework::Variable *var,
|
|
||||||
const platform::DeviceContext &ctx,
|
|
||||||
VariableMessage *msg) {
|
|
||||||
msg->set_varname(name);
|
|
||||||
std::ostringstream oss;
|
|
||||||
switch (framework::ToVarType(var->Type())) {
|
|
||||||
case framework::proto::VarDesc_VarType_LOD_TENSOR:
|
|
||||||
msg->set_type(sendrecv::VarType::LOD_TENSOR);
|
|
||||||
framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
|
|
||||||
break;
|
|
||||||
case framework::proto::VarDesc_VarType_SELECTED_ROWS:
|
|
||||||
msg->set_type(sendrecv::VarType::SELECTED_ROWS);
|
|
||||||
framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
|
|
||||||
ctx);
|
|
||||||
break;
|
|
||||||
default: {
|
|
||||||
PADDLE_THROW("Serialize does not support type: %s",
|
|
||||||
typeid(var->Type()).name());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg->set_serialized(oss.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void DeserializeFromMessage(const VariableMessage &msg,
|
|
||||||
const platform::DeviceContext &ctx,
|
|
||||||
framework::Variable *var) {
|
|
||||||
using namespace paddle::framework::proto;
|
|
||||||
std::istringstream iss(msg.serialized());
|
|
||||||
switch (msg.type()) {
|
|
||||||
case sendrecv::VarType::LOD_TENSOR:
|
|
||||||
DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
|
|
||||||
break;
|
|
||||||
case sendrecv::VarType::SELECTED_ROWS: {
|
|
||||||
DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
|
|
||||||
ctx);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
PADDLE_THROW("Deserialize does not support type: %s",
|
|
||||||
typeid(var->Type()).name());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace operators
|
|
||||||
} // namespace paddle
|
|
@ -0,0 +1,68 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/detail/sendrecvop_utils.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
// Fills `msg` with the name, type tag and serialized bytes of `var`.
//
// name: variable name stored in the message.
// var:  variable to serialize; must hold a LoDTensor or SelectedRows.
// ctx:  device context forwarded to framework::SerializeToStream.
// msg:  output message.
// Throws via PADDLE_THROW for any other variable type.
void SerializeToMessage(const std::string& name, const framework::Variable* var,
                        const platform::DeviceContext& ctx,
                        sendrecv::VariableMessage* msg) {
  msg->set_varname(name);
  std::ostringstream oss;
  switch (framework::ToVarType(var->Type())) {
    case framework::proto::VarDesc_VarType_LOD_TENSOR:
      msg->set_type(sendrecv::VarType::LOD_TENSOR);
      framework::SerializeToStream(oss, var->Get<framework::LoDTensor>(), ctx);
      break;
    case framework::proto::VarDesc_VarType_SELECTED_ROWS:
      msg->set_type(sendrecv::VarType::SELECTED_ROWS);
      framework::SerializeToStream(oss, var->Get<framework::SelectedRows>(),
                                   ctx);
      break;
    default: {
      // Report the name of the held type. typeid(var->Type()) would describe
      // the type object returned by Type(), not the variable's type.
      PADDLE_THROW("Serialize does not support type: %s", var->Type().name());
      break;
    }
  }
  msg->set_serialized(oss.str());
}
|
||||||
|
|
||||||
|
// Rebuilds `var` from the serialized bytes in `msg`.
//
// msg: message whose type() selects the target type and whose serialized()
//      holds the payload.
// ctx: device context forwarded to DeserializeFromStream.
// var: output variable; made to hold a LoDTensor or SelectedRows.
// Throws via PADDLE_THROW for any other message type.
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
                            const platform::DeviceContext& ctx,
                            framework::Variable* var) {
  std::istringstream iss(msg.serialized());
  switch (msg.type()) {
    case sendrecv::VarType::LOD_TENSOR:
      DeserializeFromStream(iss, var->GetMutable<framework::LoDTensor>(), ctx);
      break;
    case sendrecv::VarType::SELECTED_ROWS: {
      DeserializeFromStream(iss, var->GetMutable<framework::SelectedRows>(),
                            ctx);
      break;
    }
    default: {
      // The unsupported value is the message's type tag, so report that
      // rather than anything derived from `var`.
      PADDLE_THROW("Deserialize does not support type: %d",
                   static_cast<int>(msg.type()));
      break;
    }
  }
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue