|
|
|
@ -30,7 +30,6 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
|
|
|
|
|
auto &g = graphs.back();
|
|
|
|
|
g->Set(kGraphVars, new GraphVars(1UL));
|
|
|
|
|
g->Set(kGraphDepVars, new GraphDepVars);
|
|
|
|
|
g->Set(kGraphOps, new GraphOps);
|
|
|
|
|
}
|
|
|
|
|
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
|
|
|
|
|
|
|
|
|
@ -38,9 +37,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
|
|
|
|
|
auto &dev_ctx = op->DeviceContext();
|
|
|
|
|
auto &p = dev_ctx.begin()->first;
|
|
|
|
|
int dev_id = boost::get<platform::CUDAPlace>(p).device;
|
|
|
|
|
auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
|
|
|
|
|
auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
|
|
|
|
|
dev_ops.emplace_back(op);
|
|
|
|
|
graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release());
|
|
|
|
|
|
|
|
|
|
for (auto &var : op->Inputs()) {
|
|
|
|
|