|
|
|
@ -69,43 +69,47 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
auto param_list = Attr<std::vector<std::string>>("ParamList");
|
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
|
for (size_t i = 0; i < param_count; ++i) {
|
|
|
|
|
// blocking get one var from client.
|
|
|
|
|
const detail::TensorWithName &v = rpc_service_->Get();
|
|
|
|
|
auto grad_var_name = v.first;
|
|
|
|
|
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
|
|
|
|
|
std::string param_var_name;
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
|
while (true) {
|
|
|
|
|
// TODO(typhoonzero): get from multiple trainers.
|
|
|
|
|
for (size_t i = 0; i < param_count; ++i) {
|
|
|
|
|
// blocking get one var from client.
|
|
|
|
|
const detail::TensorWithName &v = rpc_service_->Get();
|
|
|
|
|
auto grad_var_name = v.first;
|
|
|
|
|
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
|
|
|
|
|
std::string param_var_name;
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
}
|
|
|
|
|
VLOG(10) << "recved grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
auto *var = recv_scope.Var(grad_var_name);
|
|
|
|
|
auto *tensor = var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
// FIXME(typhoonzero): do not copy
|
|
|
|
|
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
|
|
|
|
|
}
|
|
|
|
|
VLOG(10) << "recved grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
auto *var = recv_scope.Var(grad_var_name);
|
|
|
|
|
auto *tensor = var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
// FIXME(typhoonzero): do not copy
|
|
|
|
|
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
|
framework::ProgramDesc program_desc;
|
|
|
|
|
program_desc.ParseFromString(program_str);
|
|
|
|
|
framework::ProgramDescBind program(program_desc);
|
|
|
|
|
framework::Executor executor(dev_ctx);
|
|
|
|
|
// Run sub graph to get optimized tensor
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(program, &recv_scope, 0, /*global_block*/
|
|
|
|
|
false /*create_local_scope*/, false /*create_vars*/);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
|
framework::ProgramDesc program_desc;
|
|
|
|
|
program_desc.ParseFromString(program_str);
|
|
|
|
|
framework::ProgramDescBind program(program_desc);
|
|
|
|
|
framework::Executor executor(dev_ctx);
|
|
|
|
|
// Run sub graph to get optimized tensor
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(program, &recv_scope, 0, /*global_block*/
|
|
|
|
|
false /*create_local_scope*/, false /*create_vars*/);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < param_count; ++i) {
|
|
|
|
|
auto *out_var = recv_scope.FindVar(param_list[i]);
|
|
|
|
|
detail::TensorWithName out;
|
|
|
|
|
out.first = param_list[i];
|
|
|
|
|
out.second = out_var->Get<framework::LoDTensor>();
|
|
|
|
|
rpc_service_->Push(out);
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < param_count; ++i) {
|
|
|
|
|
auto *out_var = recv_scope.FindVar(param_list[i]);
|
|
|
|
|
detail::TensorWithName out;
|
|
|
|
|
out.first = param_list[i];
|
|
|
|
|
out.second = out_var->Get<framework::LoDTensor>();
|
|
|
|
|
rpc_service_->Push(out);
|
|
|
|
|
}
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|