|
|
|
@ -142,7 +142,7 @@ void MultiDevSSAGraphBuilder::Init() const {
|
|
|
|
|
places_ = Get<const std::vector<platform::Place>>(kPlaces);
|
|
|
|
|
local_scopes_ = Get<const std::vector<Scope *>>(kLocalScopes);
|
|
|
|
|
strategy_ = Get<const BuildStrategy>(kStrategy);
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
@ -431,7 +431,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
bool use_gpu = false;
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
use_gpu = nccl_ctxs_ != nullptr;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
@ -478,7 +478,7 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::SetCommunicationContext(
|
|
|
|
|
OpHandleBase *op_handle, const platform::Place &p) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
@ -492,7 +492,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
auto *op_handle = new BroadcastOpHandle(
|
|
|
|
|
result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
@ -522,7 +522,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
|
|
|
|
|
ir::Graph *result,
|
|
|
|
|
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
auto *op_handle = new FusedBroadcastOpHandle(
|
|
|
|
|
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_);
|
|
|
|
@ -568,7 +568,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_));
|
|
|
|
@ -597,7 +597,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
ir::Graph *result, const std::vector<std::string> &datas) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_));
|
|
|
|
@ -694,7 +694,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
|
|
|
|
|
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
|
|
|
|
|
const std::string &og,
|
|
|
|
|
int dst_dev_id) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|