polish code test=develop

revert-15774-anakin_subgraph_engine
Yancey1989 6 years ago
parent 0f8bd73cc9
commit d5090c892d

@ -34,7 +34,7 @@ namespace details {
static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) { static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
// Should fix the allreduce op order if scheduling // Should fix the allreduce op order if scheduling
// them in multiple threads or processes to avoid hang. // them in multiple threads or processes to avoid hang.
// NOTE: ParallelExecutor would execute this pass on each graph, so // NOTE: ParallelGraph would execute this pass on each graph, so
// we don't need to append it here. // we don't need to append it here.
return (!strategy.enable_sequential_execution_ && return (!strategy.enable_sequential_execution_ &&
strategy.num_trainers_ > 1) && strategy.num_trainers_ > 1) &&

@ -389,8 +389,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
OpHandleBase *op_handle = nullptr; OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&]( auto append_allreduce_op = [&](
std::vector<Scope *> &scopes, const std::vector<Scope *> &scopes,
std::vector<platform::Place> &places) -> OpHandleBase * { const std::vector<platform::Place> &places) -> OpHandleBase * {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@ -407,13 +407,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
op_handle = append_allreduce_op(local_scopes_, places_); op_handle = append_allreduce_op(local_scopes_, places_);
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto p = places_[i]; if (strategy_.enable_parallel_graph_) {
std::vector<Scope *> ss{local_scopes_[i]}; op_handle = append_allreduce_op({local_scopes_[i]}, {places_[i]});
std::vector<platform::Place> ps{p}; }
if (strategy_.enable_parallel_graph_)
op_handle = append_allreduce_op(ss, ps);
SetCommunicationContext(op_handle, p); SetCommunicationContext(op_handle, places_[i]);
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
@ -421,7 +419,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
auto var = auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, p); vars.size(), i, og, places_[i]);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }

@ -32,8 +32,9 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(
g->Set(kGraphDepVars, new GraphDepVars); g->Set(kGraphDepVars, new GraphDepVars);
g->Set(kGraphOps, new GraphOps); g->Set(kGraphOps, new GraphOps);
} }
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : graph->Get<GraphOps>(kGraphOps)) { for (auto &op : op_handles) {
auto &dev_ctx = op->DeviceContext(); auto &dev_ctx = op->DeviceContext();
auto &p = dev_ctx.begin()->first; auto &p = dev_ctx.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;

Loading…
Cancel
Save