|
|
|
@ -389,8 +389,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
|
|
|
|
|
OpHandleBase *op_handle = nullptr;
|
|
|
|
|
|
|
|
|
|
auto append_allreduce_op = [&](
|
|
|
|
|
std::vector<Scope *> &scopes,
|
|
|
|
|
std::vector<platform::Place> &places) -> OpHandleBase * {
|
|
|
|
|
const std::vector<Scope *> &scopes,
|
|
|
|
|
const std::vector<platform::Place> &places) -> OpHandleBase * {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
@ -407,13 +407,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
|
|
|
|
|
op_handle = append_allreduce_op(local_scopes_, places_);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto p = places_[i];
|
|
|
|
|
std::vector<Scope *> ss{local_scopes_[i]};
|
|
|
|
|
std::vector<platform::Place> ps{p};
|
|
|
|
|
if (strategy_.enable_parallel_graph_)
|
|
|
|
|
op_handle = append_allreduce_op(ss, ps);
|
|
|
|
|
if (strategy_.enable_parallel_graph_) {
|
|
|
|
|
op_handle = append_allreduce_op({local_scopes_[i]}, {places_[i]});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
SetCommunicationContext(op_handle, places_[i]);
|
|
|
|
|
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
@ -421,7 +419,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
|
|
|
|
|
|
|
|
|
|
auto var =
|
|
|
|
|
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
|
|
|
|
|
vars.size(), i, og, p);
|
|
|
|
|
vars.size(), i, og, places_[i]);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
|
}
|
|
|
|
|