|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/send_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/scope.h"
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -34,26 +35,46 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
|
platform::NCCLContextMap *nccl_ctxs)
|
|
|
|
|
platform::NCCLContextMap *nccl_ctxs, bool distributed)
|
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
|
places_(places),
|
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
distributed_(distributed),
|
|
|
|
|
nccl_ctxs_(nccl_ctxs) {
|
|
|
|
|
#else
|
|
|
|
|
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
|
const std::vector<Scope *> &local_scopes)
|
|
|
|
|
const std::vector<Scope *> &local_scopes, bool distributed)
|
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
|
places_(places),
|
|
|
|
|
local_scopes_(local_scopes) {
|
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
distributed_(distributed) {
|
|
|
|
|
#endif
|
|
|
|
|
for (auto &p : params) {
|
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
|
|
|
|
|
const platform::Place &p,
|
|
|
|
|
const size_t &i) const {
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
|
|
|
|
|
auto var_names = op->InputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
const ProgramDesc &program) const {
|
|
|
|
|
auto graph = new SSAGraph();
|
|
|
|
@ -72,6 +93,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// append send op if program is distributed trainer main program.
|
|
|
|
|
// always use the first device
|
|
|
|
|
if (is_forwarding && distributed_ && op->Type() == "send") {
|
|
|
|
|
auto &p = places_[0];
|
|
|
|
|
auto *s = local_scopes_[0];
|
|
|
|
|
size_t i = 0;
|
|
|
|
|
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, i);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
auto *s = local_scopes_[i];
|
|
|
|
@ -81,18 +113,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
|
|
|
|
|
auto var_names = op->InputArgumentNames();
|
|
|
|
|
CreateOpHandleIOs(&result, op, p, i);
|
|
|
|
|
// auto var_names = op->InputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
VarHandle *var =
|
|
|
|
|
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
var_names = op->OutputArgumentNames();
|
|
|
|
|
// for (auto &each_var_name : var_names) {
|
|
|
|
|
// VarHandle *var =
|
|
|
|
|
// CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
|
|
|
|
|
// op_handle->AddInput(var);
|
|
|
|
|
// }
|
|
|
|
|
auto var_names = op->OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(&result, op_handle, each_var_name, p, i);
|
|
|
|
|
}
|
|
|
|
|
// for (auto &each_var_name : var_names) {
|
|
|
|
|
// CreateOpOutput(&result, op_handle, each_var_name, p, i);
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
if (is_forwarding) {
|
|
|
|
|
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
|
|
|
|
|