|
|
|
@ -14,14 +14,18 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/scope.h"
|
|
|
|
|
#include "paddle/fluid/platform/nccl_helper.h"
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
@ -32,6 +36,16 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
places_(places),
|
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
nccl_ctxs_(nccl_ctxs) {
|
|
|
|
|
#else
|
|
|
|
|
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
|
const std::vector<Scope *> &local_scopes)
|
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
|
places_(places),
|
|
|
|
|
local_scopes_(local_scopes) {
|
|
|
|
|
#endif
|
|
|
|
|
for (auto &p : params) {
|
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
|
|
}
|
|
|
|
@ -78,9 +92,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
|
|
|
|
|
if (is_forwarding) {
|
|
|
|
|
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto *communication_dev_ctx = nccl_ctxs_->DevCtx(p);
|
|
|
|
|
#else
|
|
|
|
|
auto *communication_dev_ctx =
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
|
|
|
|
|
nccl_ctxs_->DevCtx(p));
|
|
|
|
|
communication_dev_ctx);
|
|
|
|
|
result.ops_.emplace_back(op_handle);
|
|
|
|
|
|
|
|
|
|
// FIXME: Currently ScaleLossGradOp only use device_count as scale
|
|
|
|
@ -103,7 +124,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
auto var_names = op->OutputArgumentNames();
|
|
|
|
|
for (auto &og : var_names) {
|
|
|
|
|
if (grad_names_.count(og) != 0) { // is param grad
|
|
|
|
|
// Insert NCCL AllReduce Op
|
|
|
|
|
// Insert NCCL AllReduce Op
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result.ops_.emplace_back(
|
|
|
|
|
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
|
|
|
|
|
auto *op_handle = result.ops_.back().get();
|
|
|
|
@ -125,6 +147,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
|
|
|
|
|
op_handle->AddOutput(&var);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ENFORCE("Not implemented");
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -143,7 +168,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::unique_ptr<SSAGraph>(graph);
|
|
|
|
|
}
|
|
|
|
|
} // namespace details
|
|
|
|
|
} // namespace details
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|