|
|
|
@ -75,13 +75,6 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
server_thread_->join();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GetGradVarNameForTrainer(const std::string &varname) const {
|
|
|
|
|
if (grads_counter_.find(varname) == grads_counter_.end()) {
|
|
|
|
|
grads_counter_[varname] = 0;
|
|
|
|
|
}
|
|
|
|
|
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &dev_place) const override {
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
@ -91,8 +84,7 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
// FIXME(Yancey1989): initialize rpc server with lazy mode.
|
|
|
|
|
rpc_service_->SetScope(&recv_scope);
|
|
|
|
|
rpc_service_->SetDevCtx(&dev_ctx);
|
|
|
|
|
auto param_list = Attr<std::vector<std::string>>("ParamList");
|
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
|
auto ins = Inputs("X");
|
|
|
|
|
auto fan_in = Attr<int>("Fanin");
|
|
|
|
|
|
|
|
|
|
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
|
|
|
|
@ -109,40 +101,24 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
// the gradients arrive, just add suffix 0~n and merge the gradient.
|
|
|
|
|
rpc_service_->SetCond(0);
|
|
|
|
|
size_t recv_var_cnt = 0;
|
|
|
|
|
size_t update_param_cnt = 0;
|
|
|
|
|
int batch_barrier = 0;
|
|
|
|
|
while (batch_barrier != fan_in) {
|
|
|
|
|
const detail::MessageWithName &v = rpc_service_->Get();
|
|
|
|
|
auto grad_var_name = v.first;
|
|
|
|
|
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
|
|
|
|
|
auto recv_var_name = v.first;
|
|
|
|
|
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
|
|
|
|
|
LOG(INFO) << "received terminate message and exit";
|
|
|
|
|
exit_flag = true;
|
|
|
|
|
break;
|
|
|
|
|
} else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
|
|
|
|
|
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
|
|
|
|
|
VLOG(3) << "recv batch barrier message";
|
|
|
|
|
batch_barrier++;
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
// receive a variable
|
|
|
|
|
VLOG(3) << "received grad: " << recv_var_name;
|
|
|
|
|
recv_var_cnt++;
|
|
|
|
|
auto it =
|
|
|
|
|
std::find(grad_list.begin(), grad_list.end(), grad_var_name);
|
|
|
|
|
std::string param_var_name;
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
update_param_cnt++;
|
|
|
|
|
VLOG(3) << "received grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "received variable: " << grad_var_name
|
|
|
|
|
<< " no need to update param";
|
|
|
|
|
}
|
|
|
|
|
if (fan_in > 1 && !param_var_name.empty()) {
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
auto *var = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
auto *var = recv_scope.FindVar(recv_var_name);
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
detail::DeserializeFromMessage(v.second, dev_ctx, var);
|
|
|
|
@ -151,29 +127,26 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
|
|
|
|
|
if (exit_flag) {
|
|
|
|
|
rpc_service_->ShutDown();
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "run optimize graph...";
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
|
|
|
|
|
false /*create_local_scope*/, false /*create_vars*/);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reset the received sparse variables, the sum operator would not
|
|
|
|
|
// sum the input sparse variables which rows is empty at the next
|
|
|
|
|
// mini-batch.
|
|
|
|
|
// TOOD(Yancey1989): move the reset action into an operator, we couldn't
|
|
|
|
|
// TODO(Yancey1989): move the reset action into an operator, we couldn't
|
|
|
|
|
// have any hidden logic in the operator.
|
|
|
|
|
for (auto &var : sparse_vars) {
|
|
|
|
|
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
|
|
|
|
|
}
|
|
|
|
|
rpc_service_->SetCond(1);
|
|
|
|
|
rpc_service_->WaitClientGet(update_param_cnt);
|
|
|
|
|
grads_counter_.clear();
|
|
|
|
|
// FIXME(typhoonzero): use another condition to sync wait clients get.
|
|
|
|
|
rpc_service_->WaitClientGet(ins.size());
|
|
|
|
|
sparse_vars.clear();
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
@ -181,13 +154,13 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
protected:
|
|
|
|
|
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
|
|
|
|
|
std::shared_ptr<std::thread> server_thread_;
|
|
|
|
|
mutable std::unordered_map<std::string, int> grads_counter_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
ListenAndServ operator
|
|
|
|
|
|
|
|
|
@ -201,16 +174,7 @@ from send_op and send back variables to recv_op.
|
|
|
|
|
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
|
|
|
|
|
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
|
|
|
|
|
"BlockID to run on server side.");
|
|
|
|
|
AddAttr<std::vector<std::string>>(
|
|
|
|
|
"ParamList", "type list of string",
|
|
|
|
|
"grad->param name mapping to find which parameters to optimize.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<std::vector<std::string>>(
|
|
|
|
|
"GradList", "type list of string",
|
|
|
|
|
"grad->param name mapping to find which parameters to optimize.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<int>("Fanin", "type int",
|
|
|
|
|
"Number of trainers in the current cluster job")
|
|
|
|
|
AddAttr<int>("Fanin", "How many clients send to this server.")
|
|
|
|
|
.SetDefault(1);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|