|
|
|
@ -156,7 +156,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
|
|
|
|
|
places_ = Get<const std::vector<platform::Place>>(details::kPlaces);
|
|
|
|
|
local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes);
|
|
|
|
|
strategy_ = Get<const details::BuildStrategy>(kStrategy);
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
multi_nccl_ctxs_ = &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
|
|
|
|
|
nccl_ctxs_ = nullptr;
|
|
|
|
|
if (multi_nccl_ctxs_) {
|
|
|
|
@ -298,7 +298,7 @@ std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilderBase::UseGPU() const {
|
|
|
|
|
bool use_gpu = false;
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
use_gpu = nccl_ctxs_ != nullptr;
|
|
|
|
|
#endif
|
|
|
|
|
return use_gpu;
|
|
|
|
@ -348,7 +348,7 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
|
|
|
|
|
details::OpHandleBase *op_handle, const platform::Place &p) const {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
@ -362,7 +362,7 @@ void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
|
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
auto *op_handle = new details::BroadcastOpHandle(
|
|
|
|
|
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
@ -395,7 +395,7 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
|
|
|
|
|
ir::Graph *result,
|
|
|
|
|
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
auto *op_handle = new details::FusedBroadcastOpHandle(
|
|
|
|
|
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
@ -451,7 +451,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
|
|
|
|
|
auto append_allreduce_op = [&](
|
|
|
|
|
const std::vector<Scope *> &scopes,
|
|
|
|
|
const std::vector<platform::Place> &places) -> details::OpHandleBase * {
|
|
|
|
|
#if defined(PADDLE_WITH_DGC)
|
|
|
|
|
#if defined(PADDLE_WITH_DGC) && defined(PADDLE_WITH_NCCL)
|
|
|
|
|
if (is_encoded) {
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
|
new details::SparseAllReduceOpHandle(
|
|
|
|
@ -464,7 +464,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
|
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
|
scopes, places, multi_nccl_ctxs_));
|
|
|
|
|
}
|
|
|
|
|
#elif defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#elif defined(PADDLE_WITH_NCCL)
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
|
new details::AllReduceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
@ -539,7 +539,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
|
|
|
|
|
|
|
|
|
|
details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
|
|
|
|
|
ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new details::ReduceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|