commit
6332bd1ed8
@ -0,0 +1,193 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "mkldnn_activation_op.h"
|
||||
#include "paddle/fluid/operators/activation_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using paddle::framework::Tensor;
|
||||
using paddle::platform::MKLDNNDeviceContext;
|
||||
|
||||
namespace {
|
||||
template <typename T, typename ExecContext>
|
||||
void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
|
||||
const T alpha = 0, const T beta = 0) {
|
||||
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
|
||||
"It must use CPUPlace.");
|
||||
|
||||
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
|
||||
const auto &mkldnn_engine = dev_ctx.GetEngine();
|
||||
|
||||
// get buffers
|
||||
const auto *src = ctx.template Input<Tensor>("X");
|
||||
const auto *src_data = src->template data<T>();
|
||||
|
||||
auto *dst = ctx.template Output<Tensor>("Out");
|
||||
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
// get memory dim
|
||||
PADDLE_ENFORCE(src->dims().size() == 4,
|
||||
"Input dim must be with 4, i.e. NCHW");
|
||||
std::vector<int> src_tz = framework::vectorize2int(src->dims());
|
||||
|
||||
// create memory description
|
||||
// TODO(kbinias-intel): support more formats
|
||||
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
||||
mkldnn::memory::format::nchw);
|
||||
|
||||
// create memory primitives
|
||||
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data);
|
||||
auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data);
|
||||
|
||||
auto forward_desc = mkldnn::eltwise_forward::desc(
|
||||
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
|
||||
|
||||
// save prim desc into global device context to be referred in backward path
|
||||
const std::string key = ctx.op().Output("Out");
|
||||
const std::string key_eltwise_pd = key + "@eltwise_pd";
|
||||
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
|
||||
forward_desc, mkldnn_engine);
|
||||
dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
|
||||
|
||||
auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
|
||||
|
||||
// push primitive to stream and wait until it's executed
|
||||
std::vector<mkldnn::primitive> pipeline = {eltwise};
|
||||
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
||||
}
|
||||
|
||||
template <typename T, typename ExecContext>
|
||||
void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
|
||||
const T alpha = 0, const T beta = 0) {
|
||||
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
|
||||
const auto &mkldnn_engine = dev_ctx.GetEngine();
|
||||
|
||||
// get buffers
|
||||
const auto *x = ctx.template Input<Tensor>("X");
|
||||
const auto *src = x->template data<T>();
|
||||
|
||||
auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
|
||||
const auto *diff_dst = dout->template data<T>();
|
||||
|
||||
auto *dx =
|
||||
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
// get memory dim
|
||||
std::vector<int> src_tz = framework::vectorize2int(x->dims());
|
||||
|
||||
// create memory description
|
||||
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
|
||||
mkldnn::memory::format::nchw);
|
||||
|
||||
// create memory primitives
|
||||
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src);
|
||||
auto diff_src_memory =
|
||||
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src);
|
||||
auto diff_dst_memory =
|
||||
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst);
|
||||
|
||||
auto backward_desc =
|
||||
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
|
||||
|
||||
// retrieve eltwise primitive desc from device context
|
||||
const std::string key = ctx.op().Input("Out");
|
||||
const std::string key_eltwise_pd = key + "@eltwise_pd";
|
||||
const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_eltwise_pd);
|
||||
PADDLE_ENFORCE(forward_pd != nullptr,
|
||||
"Fail to find eltwise_pd in device context");
|
||||
auto *p_forward_pd =
|
||||
static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
|
||||
|
||||
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
|
||||
backward_desc, mkldnn_engine, *p_forward_pd);
|
||||
|
||||
auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
|
||||
diff_dst_memory, diff_src_memory);
|
||||
|
||||
// push primitive to stream and wait until it's executed
|
||||
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
|
||||
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename T, mkldnn::algorithm algorithm>
|
||||
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
|
||||
template <typename ExecContext>
|
||||
void operator()(const ExecContext &ctx) const {
|
||||
eltwise_forward<T>(ctx, algorithm);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, mkldnn::algorithm algorithm>
|
||||
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
|
||||
template <typename ExecContext>
|
||||
void operator()(const ExecContext &ctx) const {
|
||||
eltwise_grad<T>(ctx, algorithm);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using ReluMkldnnFunctor =
|
||||
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
|
||||
|
||||
template <typename T>
|
||||
using TanhMkldnnFunctor =
|
||||
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
||||
|
||||
template <typename T>
|
||||
using SqrtMkldnnFunctor =
|
||||
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;
|
||||
|
||||
template <typename T>
|
||||
using AbsMkldnnFunctor =
|
||||
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;
|
||||
|
||||
template <typename T>
|
||||
using ReluMkldnnGradFunctor =
|
||||
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
|
||||
|
||||
template <typename T>
|
||||
using TanhMkldnnGradFunctor =
|
||||
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
|
||||
|
||||
template <typename T>
|
||||
using SqrtMkldnnGradFunctor =
|
||||
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;
|
||||
|
||||
template <typename T>
|
||||
using AbsMkldnnGradFunctor =
|
||||
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
|
||||
REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace, \
|
||||
ops::MKLDNNActivationKernel<ops::functor<float>>); \
|
||||
REGISTER_OP_KERNEL( \
|
||||
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
|
||||
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
|
||||
|
||||
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
|
||||
__macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor); \
|
||||
__macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor); \
|
||||
__macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor); \
|
||||
__macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);
|
||||
|
||||
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,111 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/detail/safe_ref.h"
|
||||
|
||||
#ifdef PADDLE_WITH_MKLDNN
|
||||
#include "paddle/fluid/platform/mkldnn_helper.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename Functor>
|
||||
class MKLDNNActivationKernel
|
||||
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
PADDLE_ENFORCE(context.Input<framework::Tensor>("X") != nullptr,
|
||||
"Cannot get input tensor X, variable name = %s",
|
||||
context.op().Input("X"));
|
||||
PADDLE_ENFORCE(context.Output<framework::Tensor>("Out") != nullptr,
|
||||
"Cannot find output tensor Out, variable name = %s",
|
||||
context.op().Output("Out"));
|
||||
Functor functor;
|
||||
|
||||
auto attrs = functor.GetAttrs();
|
||||
for (auto& attr : attrs) {
|
||||
*attr.second = context.Attr<float>(attr.first);
|
||||
}
|
||||
functor(context);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Functor>
|
||||
class MKLDNNActivationGradKernel
|
||||
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
Functor functor;
|
||||
|
||||
auto attrs = functor.GetAttrs();
|
||||
for (auto& attr : attrs) {
|
||||
*attr.second = context.Attr<float>(attr.first);
|
||||
}
|
||||
functor(context);
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
framework::OpKernelType GetKernelType(
|
||||
const framework::ExecutionContext& ctx,
|
||||
const framework::OperatorWithKernel& oper) {
|
||||
framework::LibraryType library{framework::LibraryType::kPlain};
|
||||
#ifdef PADDLE_WITH_MKLDNN
|
||||
if (library == framework::LibraryType::kPlain &&
|
||||
platform::CanMKLDNNBeUsed(ctx)) {
|
||||
library = framework::LibraryType::kMKLDNN;
|
||||
}
|
||||
#endif
|
||||
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
||||
ctx.GetPlace(), layout, library);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return GetKernelType(ctx, *this);
|
||||
}
|
||||
};
|
||||
|
||||
class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
|
||||
}
|
||||
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return GetKernelType(ctx, *this);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,103 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/framework.pb.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
#include <future>
|
||||
#include "paddle/fluid/operators/detail/grpc_client.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class SendBarrierOp : public framework::OperatorBase {
|
||||
public:
|
||||
SendBarrierOp(const std::string& type,
|
||||
const framework::VariableNameMap& inputs,
|
||||
const framework::VariableNameMap& outputs,
|
||||
const framework::AttributeMap& attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& place) const override {
|
||||
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
|
||||
|
||||
auto client_var_name = Output("RPCClient");
|
||||
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
|
||||
"Can not find variable '%s' in the scope.",
|
||||
client_var_name);
|
||||
auto* client_var = scope.FindVar(client_var_name);
|
||||
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
|
||||
|
||||
// need to wait before sending send_barrier message
|
||||
PADDLE_ENFORCE(rpc_client->Wait());
|
||||
|
||||
for (auto& ep : eps) {
|
||||
VLOG(3) << "send barrier, ep: " << ep;
|
||||
rpc_client->AsyncSendBatchBarrier(ep);
|
||||
}
|
||||
PADDLE_ENFORCE(rpc_client->Wait());
|
||||
}
|
||||
};
|
||||
|
||||
// Declares the proto of the send_barrier operator: one output ("RPCClient")
// and one attribute ("endpoints", a string vector defaulting to
// {"127.0.0.1:6164"}).
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SendBarrierOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Adjacent string literals are concatenated verbatim, so each fragment
    // that continues a sentence must end with an explicit space.
    AddOutput("RPCClient",
              "(RPCClient) The RPC client object which is "
              "initialized at most once.");
    AddComment(R"DOC(
SendBarrier operator

This operator will send a send barrier signal to listen_and_serv op, so that
the Parameter Server would know all variables have been sent.
)DOC");

    AddAttr<std::vector<std::string>>("endpoints",
                                      "(string vector, default 127.0.0.1:6164) "
                                      "Server endpoints to send variables to.")
        .SetDefault({"127.0.0.1:6164"});
  }
};
|
||||
|
||||
// Marks the variable named by the first "RPCClient" output as VarType::RAW,
// creating it in the block if it cannot be found.
class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    const auto client_var_name = op_desc.Output("RPCClient").front();
    auto& client_var = block->FindRecursiveOrCreateVar(client_var_name);
    client_var.SetType(framework::proto::VarType::RAW);
  }
};
|
||||
|
||||
// Shape inference for send_barrier: intentionally a no-op.
class SendBarrierOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    // Nothing to infer.
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
// Register send_barrier with an empty gradient maker, its proto maker and
// its var-type / shape inference helpers.
REGISTER_OPERATOR(send_barrier,
                  ops::SendBarrierOp,
                  paddle::framework::EmptyGradOpMaker,
                  ops::SendBarrierOpMaker,
                  ops::SendBarrierOpVarTypeInference,
                  ops::SendBarrierOpShapeInference);
|
@ -0,0 +1,134 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/framework.pb.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
#include <future>
|
||||
#include "paddle/fluid/operators/detail/grpc_client.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
// Decides whether the variable `varname` in `scope` should be sent.
//
// Returns, for a LoDTensor, whether it is initialized; for a SelectedRows,
// whether it has at least one row. Fails an enforce when the variable does
// not exist and throws when it has any other type.
static bool NeedSend(const framework::Scope& scope,
                     const std::string& varname) {
  auto* var = scope.FindVar(varname);
  PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                          varname);

  if (var->IsType<framework::LoDTensor>()) {
    const auto& tensor = var->Get<framework::LoDTensor>();
    return tensor.IsInitialized();
  }

  if (var->IsType<framework::SelectedRows>()) {
    const auto& selected_rows = var->Get<framework::SelectedRows>();
    return selected_rows.rows().size() != 0UL;
  }

  PADDLE_THROW(
      "Variable type in send side should be in "
      "[LodTensor, SelectedRows]");
  // Kept after the throw so every path of this non-void function has a
  // return statement.
  return false;
}
|
||||
|
||||
class SendVarsOp : public framework::OperatorBase {
|
||||
public:
|
||||
SendVarsOp(const std::string& type, const framework::VariableNameMap& inputs,
|
||||
const framework::VariableNameMap& outputs,
|
||||
const framework::AttributeMap& attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& place) const override {
|
||||
auto ins = Inputs("X");
|
||||
|
||||
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
|
||||
int sync_send = Attr<int>("sync_sent");
|
||||
|
||||
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
||||
auto& ctx = *pool.Get(place);
|
||||
|
||||
auto client_var_name = Output("RPCClient");
|
||||
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
|
||||
"Can not find variable '%s' in the scope.",
|
||||
client_var_name);
|
||||
auto* client_var = scope.FindVar(client_var_name);
|
||||
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
|
||||
|
||||
for (size_t i = 0; i < ins.size(); i++) {
|
||||
if (NeedSend(scope, ins[i])) {
|
||||
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
|
||||
// TODO(Yancey1989): we need to use an IO threadpool which has
|
||||
// a larger number of threads than the computing threadpool.
|
||||
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
|
||||
} else {
|
||||
VLOG(3) << "don't send no-initialied variable: " << ins[i];
|
||||
}
|
||||
}
|
||||
if (sync_send) {
|
||||
rpc_client->Wait();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
SendVarsOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
|
||||
.AsDuplicable();
|
||||
AddOutput("RPCClient",
|
||||
"(RPCClient) The RPC client object which will be"
|
||||
"initialized at most once.");
|
||||
AddComment(R"DOC(
|
||||
Send operator
|
||||
|
||||
This operator will send variables to listen_and_serve op at the parameter server.
|
||||
)DOC");
|
||||
AddAttr<int>("ync_send",
|
||||
"(int, default 0)"
|
||||
"sync send or async send.")
|
||||
.SetDefault(0);
|
||||
AddAttr<std::vector<std::string>>("epmap",
|
||||
"(string vector, default 127.0.0.1:6164)"
|
||||
"Server endpoints in the order of input "
|
||||
"variables for mapping")
|
||||
.SetDefault({"127.0.0.1:6164"});
|
||||
}
|
||||
};
|
||||
|
||||
// Marks the variable named by the first "RPCClient" output as VarType::RAW,
// creating it in the block if it cannot be found.
class SendVarsOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    const auto client_var_name = op_desc.Output("RPCClient").front();
    auto& client_var = block->FindRecursiveOrCreateVar(client_var_name);
    client_var.SetType(framework::proto::VarType::RAW);
  }
};
|
||||
|
||||
// Shape inference for send_vars: intentionally a no-op.
class SendVarsOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    // Nothing to infer.
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
// Register send_vars with an empty gradient maker, its proto maker and its
// var-type / shape inference helpers.
REGISTER_OPERATOR(send_vars,
                  ops::SendVarsOp,
                  paddle::framework::EmptyGradOpMaker,
                  ops::SendVarsOpMaker,
                  ops::SendVarsOpVarTypeInference,
                  ops::SendVarsOpShapeInference);
|
Loading…
Reference in new issue