|
|
|
@ -23,6 +23,7 @@ namespace details {
|
|
|
|
|
|
|
|
|
|
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
|
|
|
|
|
Scope *scope) {
|
|
|
|
|
VLOG(3) << "NewTempScopeAndInitVars";
|
|
|
|
|
Scope &local_scope = scope->NewScope();
|
|
|
|
|
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
|
|
|
|
|
&local_scope;
|
|
|
|
@ -43,12 +44,15 @@ inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
|
|
|
|
|
// get RpcContext and remote send and recv op
|
|
|
|
|
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
|
|
|
|
|
using RpcCtxMap = operators::distributed::RpcCtxMap;
|
|
|
|
|
VLOG(3) << "ProcessGraph";
|
|
|
|
|
RpcCtxMap send_varname_to_ctx;
|
|
|
|
|
RpcCtxMap recv_varname_to_ctx;
|
|
|
|
|
for (auto i = 0; i < graphs.size(); ++i) {
|
|
|
|
|
for (auto &node : graphs[i]->Nodes()) {
|
|
|
|
|
if (node->IsOp()) {
|
|
|
|
|
if (node->Op()->Type() == "send") {
|
|
|
|
|
VLOG(3) << "node name " << node->Name();
|
|
|
|
|
std::vector<ir::Node *> nodes_to_delete;
|
|
|
|
|
if (node && node->IsOp()) {
|
|
|
|
|
if (node->Name() == "send") {
|
|
|
|
|
auto send_var_name = node->Op()->Input("X")[0];
|
|
|
|
|
auto send_varnames = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->GetNullableAttr("send_varnames"));
|
|
|
|
@ -61,8 +65,8 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
|
|
|
|
|
epmap, height_section);
|
|
|
|
|
VLOG(3) << "find and init an send op: "
|
|
|
|
|
<< send_varname_to_ctx[send_var_name];
|
|
|
|
|
} else if (node->Op()->Type() == "recv") {
|
|
|
|
|
auto recv_var_name = node->Op()->Input("X")[0];
|
|
|
|
|
} else if (node->Name() == "recv") {
|
|
|
|
|
auto recv_var_name = node->Op()->Output("Out")[0];
|
|
|
|
|
auto recv_varnames = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->GetNullableAttr("recv_varnames"));
|
|
|
|
|
auto epmap = boost::get<std::vector<std::string>>(
|
|
|
|
@ -70,18 +74,23 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
|
|
|
|
|
recv_varname_to_ctx[recv_var_name] =
|
|
|
|
|
operators::distributed::RpcContext(recv_var_name, recv_varnames,
|
|
|
|
|
epmap, {});
|
|
|
|
|
graphs[i]->RemoveNode(node);
|
|
|
|
|
nodes_to_delete.push_back(node);
|
|
|
|
|
VLOG(3) << "find and remove an recv op: "
|
|
|
|
|
<< recv_varname_to_ctx[recv_var_name];
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "delete all recv ops";
|
|
|
|
|
for (auto *node : nodes_to_delete) {
|
|
|
|
|
graphs[i]->RemoveNode(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// init communicator here
|
|
|
|
|
if (send_varname_to_ctx.size() > 0) {
|
|
|
|
|
VLOG(3) << "this is distribute mode, will use ";
|
|
|
|
|
VLOG(3) << "this is distribute mode, will use communicator";
|
|
|
|
|
operators::distributed::Communicator::Init(send_varname_to_ctx,
|
|
|
|
|
recv_varname_to_ctx, scope);
|
|
|
|
|
operators::distributed::Communicator::GetInstance()->Start();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|