|
|
|
@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
|
|
|
|
|
std::vector<std::string> varnames =
|
|
|
|
|
Attr<std::vector<std::string>>("varnames");
|
|
|
|
|
int sync_mode = Attr<int>("sync_mode");
|
|
|
|
|
|
|
|
|
|
auto outs = Outputs("Out");
|
|
|
|
|
bool with_barrier = Attr<bool>("with_barrier");
|
|
|
|
|
|
|
|
|
@ -64,8 +64,8 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
trainer_id);
|
|
|
|
|
recv_functor(rpc_ctx, scope);
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<distributed::VarHandlePtr> rets;
|
|
|
|
|
if (with_barrier) {
|
|
|
|
|
std::vector<distributed::VarHandlePtr> rets;
|
|
|
|
|
for (size_t i = 0; i < outs.size(); i++) {
|
|
|
|
|
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
|
|
|
|
|
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
|
|
|
|
@ -73,13 +73,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
rets.push_back(
|
|
|
|
|
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
|
|
|
|
|
}
|
|
|
|
|
if (sync_mode) {
|
|
|
|
|
for (size_t i = 0; i < rets.size(); i++) {
|
|
|
|
|
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<distributed::VarHandlePtr> rets;
|
|
|
|
|
for (size_t i = 0; i < outs.size(); i++) {
|
|
|
|
|
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
|
|
|
|
|
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
|
|
|
|
@ -87,9 +81,11 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
|
|
|
|
|
varname, outs[i]));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < rets.size(); i++) {
|
|
|
|
|
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < rets.size(); i++) {
|
|
|
|
|
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
|
|
|
|
|
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
|
|
|
|
|
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -112,10 +108,6 @@ This operator can get variables from server side.
|
|
|
|
|
"variables for mapping")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
|
|
|
|
|
AddAttr<int>("sync_mode",
|
|
|
|
|
"(int, default 0)"
|
|
|
|
|
"sync recv or async recv.")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
AddAttr<bool>("with_barrier",
|
|
|
|
|
"(bool, default True) if with_barrier=False, will use "
|
|
|
|
|
"AsyncGetVarNoBarrier get variable from pserver immediately")
|
|
|
|
|