@@ -66,37 +66,25 @@ class RecvOp : public framework::OperatorBase {
            const platform::DeviceContext &dev_ctx) const override {
     // FIXME(typhoonzero): no new scopes for every run.
     framework::Scope &recv_scope = scope.NewScope();
-    // blocking get one var from client.
-    const detail::TensorWithName &v = rpc_service_->Get();
-    auto grad_var_name = v.first;
-
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
-    std::string param_var_name;
-    if (it != grad_list.end()) {
-      param_var_name = param_list[it - grad_list.begin()];
+    size_t param_count = param_list.size();
+    for (size_t i = 0; i < param_count; ++i) {
+      // blocking get one var from client.
+      const detail::TensorWithName &v = rpc_service_->Get();
+      auto grad_var_name = v.first;
+      auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
+      std::string param_var_name;
+      if (it != grad_list.end()) {
+        param_var_name = param_list[it - grad_list.begin()];
+      }
+      VLOG(10) << "recved grad: " << grad_var_name
+               << " updating param: " << param_var_name;
+      auto *var = recv_scope.Var(grad_var_name);
+      auto *tensor = var->GetMutable<framework::LoDTensor>();
+      // FIXME(typhoonzero): do not copy
+      framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
     }
-    // find input by "grad_var_name"
-    // auto inputs = Inputs("RX");
-
-    // FIXME(typhoonzero): Find the parameter name from input grad name
-    // rename X -> Param
-    // rename RX -> Grad
-
-    LOG(ERROR) << "recved grad: " << grad_var_name
-               << " param: " << param_var_name;
-    auto *var = recv_scope.Var(grad_var_name);
-    auto *tensor = var->GetMutable<framework::LoDTensor>();
-
-    // Param is in parent scope, put it in current scope.
-    auto *param_var = recv_scope.FindVar(param_var_name);
-    auto param_scope = recv_scope.FindScope(param_var);
-    param_scope->Rename(param_var_name, "Param");
-    recv_scope.Rename(grad_var_name, "Grad");
-
-    // FIXME(typhoonzero): do not copy
-    framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);

     std::string program_str = Attr<std::string>("OptimizeProgram");
     framework::ProgramDesc program_desc;
@@ -104,17 +92,20 @@ class RecvOp : public framework::OperatorBase {
     framework::ProgramDescBind program(program_desc);
     framework::Executor executor(dev_ctx);
     // Run sub graph to get optimized tensor
-    executor.Run(program, &recv_scope, 0, /*global_block*/
-                 false /*create_local_scope*/);
-
-    auto *out_var = recv_scope.FindVar("ParamOut");
-    detail::TensorWithName out;
-    out.first = param_var_name;
-    out.second = out_var->Get<framework::LoDTensor>();
-    rpc_service_->Push(out);
-    // rename back the params
-    param_scope.Rename("Param", param_var_name);
-    recv_scope.Rename("Grad", grad_var_name);
+    try {
+      executor.Run(program, &recv_scope, 0, /*global_block*/
+                   false /*create_local_scope*/, false /*create_vars*/);
+    } catch (std::exception &e) {
+      LOG(ERROR) << "run sub program error " << e.what();
+    }
+
+    for (size_t i = 0; i < param_count; ++i) {
+      auto *out_var = recv_scope.FindVar(param_list[i]);
+      detail::TensorWithName out;
+      out.first = param_list[i];
+      out.second = out_var->Get<framework::LoDTensor>();
+      rpc_service_->Push(out);
+    }
   }

  protected:
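
For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the first hunk introduces: the server blocks on a queue of (name, tensor) pairs, receives exactly one gradient per parameter, and maps each received gradient name back to its parameter by position, since "ParamList" and "GradList" are parallel attributes. Everything below (BlockingQueue, the TensorWithName alias as a name plus raw floats) is an illustrative stand-in, not PaddlePaddle's actual detail:: API.

// Sketch of the receive loop above, with a plain blocking queue
// standing in for rpc_service_->Get()/Push().
#include <algorithm>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Stand-in for detail::TensorWithName: a variable name plus its payload.
using TensorWithName = std::pair<std::string, std::vector<float>>;

// Stand-in for the RPC service: Push() enqueues a var, Get() blocks
// until one is available.
class BlockingQueue {
 public:
  void Push(TensorWithName v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      q_.push(std::move(v));
    }
    cv_.notify_one();
  }

  TensorWithName Get() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !q_.empty(); });
    TensorWithName v = std::move(q_.front());
    q_.pop();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<TensorWithName> q_;
};

int main() {
  // Parallel lists, as in the op's "ParamList"/"GradList" attributes:
  // grad_list[i] is the gradient of param_list[i].
  std::vector<std::string> param_list = {"w1", "w2"};
  std::vector<std::string> grad_list = {"w1@GRAD", "w2@GRAD"};

  BlockingQueue rpc_service;
  // Pretend a trainer already sent one gradient per parameter,
  // possibly out of order.
  rpc_service.Push({"w2@GRAD", {0.2f}});
  rpc_service.Push({"w1@GRAD", {0.1f}});

  // Receive exactly param_list.size() gradients, in whatever order
  // they arrive, and map each back to its parameter by index.
  for (size_t i = 0; i < param_list.size(); ++i) {
    TensorWithName v = rpc_service.Get();  // blocks until a var arrives
    auto it = std::find(grad_list.begin(), grad_list.end(), v.first);
    std::string param_var_name;
    if (it != grad_list.end()) {
      param_var_name = param_list[it - grad_list.begin()];
    }
    std::cout << "recved grad: " << v.first
              << " updating param: " << param_var_name << "\n";
  }
  return 0;
}

The second hunk mirrors this loop on the send side: after the optimize sub-program runs, one Push per parameter goes back out, so the trainer can block on exactly as many responses as the gradients it sent.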