|
|
|
@ -59,6 +59,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
|
|
|
|
|
send_varname_to_ctx[send_var_name] =
|
|
|
|
|
operators::distributed::RpcContext(send_var_name, send_varnames,
|
|
|
|
|
epmap, height_section);
|
|
|
|
|
VLOG(3) << "find and init an send op: "
|
|
|
|
|
<< send_varname_to_ctx[send_var_name];
|
|
|
|
|
} else if (node->Op()->Type() == "recv") {
|
|
|
|
|
auto recv_var_name = node->Op()->Input("X")[0];
|
|
|
|
|
auto recv_varnames = boost::get<std::vector<std::string>>(
|
|
|
|
@ -68,13 +70,19 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
|
|
|
|
|
recv_varname_to_ctx[recv_var_name] =
|
|
|
|
|
operators::distributed::RpcContext(recv_var_name, recv_varnames,
|
|
|
|
|
epmap, {});
|
|
|
|
|
graphs[i]->RemoveNode(node);
|
|
|
|
|
VLOG(3) << "find and remove an recv op: "
|
|
|
|
|
<< recv_varname_to_ctx[recv_var_name];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// init communicator here
|
|
|
|
|
operators::distributed::Communicator::Init(send_varname_to_ctx,
|
|
|
|
|
recv_varname_to_ctx, scope);
|
|
|
|
|
if (send_varname_to_ctx.size() > 0) {
|
|
|
|
|
VLOG(3) << "this is distribute mode, will use ";
|
|
|
|
|
operators::distributed::Communicator::Init(send_varname_to_ctx,
|
|
|
|
|
recv_varname_to_ctx, scope);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
|
|
|
|
@ -110,6 +118,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
|
|
|
|
|
for (auto *scope : local_scopes_) {
|
|
|
|
|
NewTempScopeAndInitVars(var_infos_, scope);
|
|
|
|
|
}
|
|
|
|
|
ProcessGraph(graphs_, local_scopes_[0]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
|
|
|
|
|