|
|
@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
|
|
|
|
places_ = Get<const std::vector<platform::Place>>(details::kPlaces);
|
|
|
|
places_ = Get<const std::vector<platform::Place>>(details::kPlaces);
|
|
|
|
local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes);
|
|
|
|
local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes);
|
|
|
|
strategy_ = Get<const details::BuildStrategy>(kStrategy);
|
|
|
|
strategy_ = Get<const details::BuildStrategy>(kStrategy);
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
multi_nccl_ctxs_ = &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
|
|
|
|
multi_nccl_ctxs_ = &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
|
|
|
|
nccl_ctxs_ = nullptr;
|
|
|
|
nccl_ctxs_ = nullptr;
|
|
|
|
if (multi_nccl_ctxs_) {
|
|
|
|
if (multi_nccl_ctxs_) {
|
|
|
@ -323,7 +323,7 @@ std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
|
|
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilderBase::UseGPU() const {
|
|
|
|
bool MultiDevSSAGraphBuilderBase::UseGPU() const {
|
|
|
|
bool use_gpu = false;
|
|
|
|
bool use_gpu = false;
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
use_gpu = nccl_ctxs_ != nullptr;
|
|
|
|
use_gpu = nccl_ctxs_ != nullptr;
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
return use_gpu;
|
|
|
|
return use_gpu;
|
|
|
@ -373,7 +373,7 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
|
|
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
|
|
|
|
void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
|
|
|
|
details::OpHandleBase *op_handle, const platform::Place &p) const {
|
|
|
|
details::OpHandleBase *op_handle, const platform::Place &p) const {
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
@ -392,7 +392,7 @@ void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
const std::string &p_name,
|
|
|
|
const std::string &p_name,
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
auto *op_handle = new details::BroadcastOpHandle(
|
|
|
|
auto *op_handle = new details::BroadcastOpHandle(
|
|
|
|
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
|
|
|
|
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
@ -429,7 +429,7 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
|
|
|
|
ir::Graph *result,
|
|
|
|
ir::Graph *result,
|
|
|
|
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
|
|
|
|
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
auto *op_handle = new details::FusedBroadcastOpHandle(
|
|
|
|
auto *op_handle = new details::FusedBroadcastOpHandle(
|
|
|
|
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
|
|
|
|
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
@ -499,7 +499,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
|
|
|
|
const std::vector<Scope *> &scopes,
|
|
|
|
const std::vector<Scope *> &scopes,
|
|
|
|
const std::vector<platform::Place> &places) -> details::OpHandleBase * {
|
|
|
|
const std::vector<platform::Place> &places) -> details::OpHandleBase * {
|
|
|
|
if (is_encoded) {
|
|
|
|
if (is_encoded) {
|
|
|
|
#if defined(PADDLE_WITH_DGC) && defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_DGC) && \
|
|
|
|
|
|
|
|
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
new details::SparseAllReduceOpHandle(
|
|
|
|
new details::SparseAllReduceOpHandle(
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
@ -515,7 +516,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
|
|
|
|
grad_merge_cond_name = BOOST_GET_CONST(
|
|
|
|
grad_merge_cond_name = BOOST_GET_CONST(
|
|
|
|
std::string, node->Op()->GetAttr(GRAD_MERGE_COND_NAME));
|
|
|
|
std::string, node->Op()->GetAttr(GRAD_MERGE_COND_NAME));
|
|
|
|
VLOG(10) << "og=" << og << " use grad_merge_allreduce";
|
|
|
|
VLOG(10) << "og=" << og << " use grad_merge_allreduce";
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
new details::GradMergeAllReduceOpHandle(
|
|
|
|
new details::GradMergeAllReduceOpHandle(
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
@ -532,7 +533,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
|
|
|
|
scopes, places, grad_merge_cond_name));
|
|
|
|
scopes, places, grad_merge_cond_name));
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
#ifdef PADDLE_WITH_NCCL
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
new details::AllReduceOpHandle(
|
|
|
|
new details::AllReduceOpHandle(
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
@ -648,7 +649,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
|
|
|
|
|
|
|
|
|
|
|
|
details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
|
|
|
|
details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
|
|
|
|
ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
|
|
|
|
ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new details::ReduceOpHandle(
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new details::ReduceOpHandle(
|
|
|
|
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
|
|
|
|
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
|
|
|
|
local_scopes_, places_, nccl_ctxs_));
|
|
|
|
local_scopes_, places_, nccl_ctxs_));
|
|
|
|