|
|
|
@ -106,6 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
// the gradients arrives, just add suffix 0~n and merge the gradient.
|
|
|
|
|
rpc_service_->SetCond(0);
|
|
|
|
|
size_t recv_var_cnt = 0;
|
|
|
|
|
size_t update_param_cnt = 0;
|
|
|
|
|
int batch_barrier = 0;
|
|
|
|
|
while (batch_barrier != fan_in) {
|
|
|
|
|
const detail::MessageWithName &v = rpc_service_->Get();
|
|
|
|
@ -126,13 +127,14 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
std::string param_var_name;
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
update_param_cnt++;
|
|
|
|
|
VLOG(3) << "received grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
|
|
|
|
|
VLOG(3) << "received variable: " << grad_var_name
|
|
|
|
|
<< " no need to update param";
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "received grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
|
|
|
|
|
if (fan_in > 1) {
|
|
|
|
|
if (fan_in > 1 && !param_var_name.empty()) {
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
auto *var = recv_scope.FindVar(grad_var_name);
|
|
|
|
@ -144,11 +146,10 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
|
|
|
|
|
// TODO(Yancey1989): merge SelectedRows variables here
|
|
|
|
|
if (exit_flag) {
|
|
|
|
|
rpc_service_->ShutDown();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "run optimize graph...";
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
|
|
|
|
|
false /*create_local_scope*/, false /*create_vars*/);
|
|
|
|
@ -156,7 +157,7 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
rpc_service_->SetCond(1);
|
|
|
|
|
rpc_service_->WaitClientGet(recv_var_cnt);
|
|
|
|
|
rpc_service_->WaitClientGet(update_param_cnt);
|
|
|
|
|
grads_counter_.clear();
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
|