|
|
|
@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase {
|
|
|
|
|
// Runs the send-barrier operator: reads the "endpoints" and "sync_mode"
// attributes, waits on rpc_client, and issues AsyncSendBatchBarrier() to the
// listed endpoints.
//
// Parameters:
//   scope - not referenced in the lines visible here.
//   place - used to look up a device context from the DeviceContextPool.
//
// NOTE(review): this listing appears to be a flattened diff view (hunk markers
// and "|" filler lines are interleaved with the code). Both an unconditional
// barrier loop and a sync_mode-guarded barrier loop are present below;
// presumably one of them is the superseded version -- confirm against the
// actual source file before relying on the control flow shown here.
void RunImpl(const framework::Scope& scope,
|
|
|
|
|
const platform::Place& place) const override {
|
|
|
|
|
// Endpoints that will receive the batch-barrier message.
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
|
|
|
|
|
// Gates the barrier loop that is wrapped in `if (sync_mode)` further down.
bool sync_mode = Attr<bool>("sync_mode");
|
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
// Device context for `place`; it is not used in the lines visible here.
auto& ctx = *pool.Get(place);
|
|
|
|
@ -51,12 +52,13 @@ class SendBarrierOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
// NOTE(review): `rpc_client` is declared in lines omitted by the hunk gap
// above; its type and origin cannot be confirmed from this view.
// need to wait before sending send_barrier message
|
|
|
|
|
// Presumably fails the enforce check when Wait() reports failure -- verify
// against the PADDLE_ENFORCE definition.
PADDLE_ENFORCE(rpc_client->Wait());
|
|
|
|
|
|
|
|
|
|
// NOTE(review): unconditional barrier loop; looks like the pre-change form of
// the guarded loop below -- TODO confirm.
for (auto& ep : eps) {
|
|
|
|
|
VLOG(3) << "send barrier, ep: " << ep;
|
|
|
|
|
rpc_client->AsyncSendBatchBarrier(ep);
|
|
|
|
|
// Barrier messages are sent, and then waited on, only when sync_mode is true.
if (sync_mode) {
|
|
|
|
|
for (auto& ep : eps) {
|
|
|
|
|
VLOG(3) << "send barrier, ep: " << ep;
|
|
|
|
|
rpc_client->AsyncSendBatchBarrier(ep);
|
|
|
|
|
}
|
|
|
|
|
// Wait on rpc_client after issuing the barrier calls above.
PADDLE_ENFORCE(rpc_client->Wait());
|
|
|
|
|
}
|
|
|
|
|
// NOTE(review): trailing Wait() outside the sync_mode branch; likely the
// pre-change line that was moved inside the branch -- TODO confirm.
PADDLE_ENFORCE(rpc_client->Wait());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -77,6 +79,7 @@ the Parameter Server would knew all variables have been sent.
|
|
|
|
|
"(string vector, default 127.0.0.1:6164)"
|
|
|
|
|
"Server endpoints to send variables to.")
|
|
|
|
|
.SetDefault({"127.0.0.1:6164"});
|
|
|
|
|
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|