|
|
@ -80,7 +80,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
auto trainer_count = Attr<int>("Trainers");
|
|
|
|
auto trainer_count = Attr<int>("Trainers");
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
rpc_service_->Start();
|
|
|
|
rpc_service_->Reset();
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
while (true) {
|
|
|
|
while (true) {
|
|
|
|
// Get from multiple trainers, we don't care about order in which
|
|
|
|
// Get from multiple trainers, we don't care about order in which
|
|
|
@ -93,6 +93,8 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
std::string param_var_name;
|
|
|
|
std::string param_var_name;
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
LOG(ERROR) << "grad have no paired param found!";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
VLOG(3) << "recved grad: " << grad_var_name
|
|
|
|
VLOG(3) << "recved grad: " << grad_var_name
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
@ -112,7 +114,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
// FIXME(typhoonzero): do not copy
|
|
|
|
// FIXME(typhoonzero): do not copy
|
|
|
|
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
|
|
|
|
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
rpc_service_->Start();
|
|
|
|
rpc_service_->Reset();
|
|
|
|
|
|
|
|
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
framework::ProgramDesc program_desc;
|
|
|
|
framework::ProgramDesc program_desc;
|
|
|
|