|
|
|
@ -75,8 +75,8 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
server_thread_->join();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &dev_place) const override {
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &dev_place) const override {
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(dev_place);
|
|
|
|
|
framework::Scope &recv_scope = scope.NewScope();
|
|
|
|
@ -101,7 +101,6 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
// the gradients arrives, just add suffix 0~n and merge the gradient.
|
|
|
|
|
rpc_service_->SetCond(0);
|
|
|
|
|
size_t recv_var_cnt = 0;
|
|
|
|
|
size_t update_param_cnt = 0;
|
|
|
|
|
int batch_barrier = 0;
|
|
|
|
|
while (batch_barrier != fan_in) {
|
|
|
|
|
const detail::MessageWithName &v = rpc_service_->Get();
|
|
|
|
@ -128,29 +127,26 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
|
|
|
|
|
if (exit_flag) {
|
|
|
|
|
rpc_service_->ShutDown();
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "run optimize graph...";
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
|
|
|
|
|
false /*create_local_scope*/, false /*create_vars*/);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reset the received sparse variables, the sum operator would not
|
|
|
|
|
// sum the input sparse variables which rows is empty at the next
|
|
|
|
|
// mini-batch.
|
|
|
|
|
// TOOD(Yancey1989): move the reset action into an operator, we couldn't
|
|
|
|
|
// TODO(Yancey1989): move the reset action into an operator, we couldn't
|
|
|
|
|
// have any hide logic in the operator.
|
|
|
|
|
for (auto &var : sparse_vars) {
|
|
|
|
|
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
|
|
|
|
|
}
|
|
|
|
|
rpc_service_->SetCond(1);
|
|
|
|
|
rpc_service_->WaitClientGet(update_param_cnt);
|
|
|
|
|
grads_counter_.clear();
|
|
|
|
|
// FIXME(typhoonzero): use another condition to sync wait clients get.
|
|
|
|
|
rpc_service_->WaitClientGet(ins.size());
|
|
|
|
|
sparse_vars.clear();
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
@ -158,7 +154,6 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
protected:
|
|
|
|
|
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
|
|
|
|
|
std::shared_ptr<std::thread> server_thread_;
|
|
|
|
|
mutable std::unordered_map<std::string, int> grads_counter_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|