You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
5.6 KiB
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
|
|
#include "paddle/fluid/operators/distributed/communicator.h"
|
|
#include "paddle/fluid/operators/distributed/distributed.h"
|
|
|
|
// Forward declarations only: keeps this translation unit's compile-time
// dependencies minimal. The full definitions come from the framework
// headers (e.g. op_registry.h) included above.
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
class Scope;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle
|
|
|
|
namespace paddle {
|
|
namespace operators {
|
|
|
|
// Forward declaration; the concrete RPCClient implementation (selected via
// the RPCCLIENT_T macro) comes from the distributed headers included above.
namespace distributed {
class RPCClient;
}  // namespace distributed
|
|
|
|
class RecvOp : public framework::OperatorBase {
|
|
public:
|
|
RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
|
|
const framework::VariableNameMap &outputs,
|
|
const framework::AttributeMap &attrs)
|
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
const platform::Place &place) const override {
|
|
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
|
|
std::vector<std::string> varnames =
|
|
Attr<std::vector<std::string>>("varnames");
|
|
|
|
auto outs = Outputs("Out");
|
|
bool with_barrier = Attr<bool>("with_barrier");
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
auto &ctx = *pool.Get(place);
|
|
auto trainer_id = Attr<int>("trainer_id");
|
|
|
|
distributed::RPCClient *rpc_client =
|
|
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
|
|
|
|
std::vector<std::string> recv_varnames =
|
|
Attr<std::vector<std::string>>("recv_varnames");
|
|
|
|
if (recv_varnames.size() > 0) {
|
|
auto *communicator = distributed::Communicator::GetInstance();
|
|
|
|
if (communicator != nullptr) {
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
"execute startup program must before fleet.init_worker"));
|
|
}
|
|
} else {
|
|
std::vector<distributed::VarHandlePtr> rets;
|
|
if (with_barrier) {
|
|
for (size_t i = 0; i < outs.size(); i++) {
|
|
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
|
|
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
|
|
<< varname << " and with AsyncGetVar";
|
|
rets.push_back(
|
|
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
|
|
}
|
|
} else {
|
|
for (size_t i = 0; i < outs.size(); i++) {
|
|
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
|
|
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
|
|
<< varname << " and with AsyncGetVarNoBarrier";
|
|
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
|
|
varname, outs[i]));
|
|
}
|
|
}
|
|
for (size_t i = 0; i < rets.size(); i++) {
|
|
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
|
|
PADDLE_ENFORCE_NE(
|
|
rets[i]->Wait(), 0U,
|
|
platform::errors::ExecutionTimeout("internal error in RPCClient"));
|
|
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
// Registers the proto description (inputs/outputs/attributes) of the "recv"
// operator. NOTE: the registration order of the Add* calls below is part of
// the generated op proto, so they are kept in their original order.
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    // Dummy inputs exist only so the framework can express control
    // dependencies on this op.
    AddInput("X", "(Any) Dummy inputs, used for control dependency")
        .AsDuplicable();
    AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
    AddComment(R"DOC(
Recv operator

This operator can get variables from server side.
)DOC");
    // One endpoint per output variable, in output order.
    AddAttr<std::vector<std::string>>("epmap",
                                      "(string vector, default 127.0.0.1:6164)"
                                      "Server endpoints in the order of input "
                                      "variables for mapping")
        .SetDefault({});
    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
    // with_barrier=false switches RunImpl to AsyncGetVarNoBarrier.
    AddAttr<bool>("with_barrier",
                  "(bool, default True) if with_barrier=False, will use "
                  "AsyncGetVarNoBarrier get variable from pserver immediately")
        .SetDefault(true);
    // Optional server-side names; empty means "same name as the output".
    AddAttr<std::vector<std::string>>(
        "varnames",
        "(string vector, default {}) "
        "sometimes we need to put received var in another name "
        "for example: we need var named 'moment_1@127.0.0.1:1001', "
        "and it real name on parameter server is 'moment_1'. ")
        .SetDefault({});
    // Non-empty routes the receive through the Communicator path in RunImpl.
    AddAttr<std::vector<std::string>>(
        "recv_varnames",
        "(vector<string>) "
        "the split parameter varnames to be recved from pserver")
        .SetDefault(std::vector<std::string>{});
    // NOTE(review): "do_not_run" is declared here but not read in
    // RecvOp::RunImpl in this file -- presumably consumed elsewhere
    // (e.g. by a pass or executor); verify before removing.
    AddAttr<int>("do_not_run", "if recv need to really run").SetDefault(0);
  }
};
|
|
|
|
// Shape-inference stub for the recv op: output shapes are produced by the
// remote server at run time, so nothing can be inferred statically.
class RecvOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *unused) const override {
    // Intentionally a no-op; see the class comment.
  }
};
|
|
|
|
} // namespace operators
|
|
} // namespace paddle
|
|
|
|
namespace ops = paddle::operators;

// Register "recv" with empty gradient makers for both the static-graph
// (OpDesc) and imperative (OpBase) modes: this op has no gradient.
REGISTER_OPERATOR(
    recv, ops::RecvOp,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::RecvOpMaker, ops::RecvOpShapeInference);
|