|
|
|
@ -20,16 +20,13 @@
|
|
|
|
|
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/reduce_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/rpc_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_info.h"
|
|
|
|
|
#include "paddle/fluid/framework/scope.h"
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
@ -305,7 +302,12 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
|
|
|
|
|
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
|
|
|
|
|
vars.emplace_back(out_var);
|
|
|
|
|
op_handle->AddOutput(out_var);
|
|
|
|
|
#ifndef ADDLE_WITH_CUDA
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(
|
|
|
|
|
p, platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
#endif
|
|
|
|
@ -324,7 +326,10 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
|
|
|
|
|
SSAGraph *result, const std::string &og) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->ops_.emplace_back(
|
|
|
|
|
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
|
|
|
|
|
new NCCLAllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->ops_.emplace_back(new NCCLAllReduceOpHandle(local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
@ -334,13 +339,23 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(
|
|
|
|
|
p, platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "NCCL - - - " << p;
|
|
|
|
|
op_handle->DeviceContext(p)->Wait();
|
|
|
|
|
VLOG(4) << "NCCL - - - " << p << " " << op_handle->DeviceContext(p);
|
|
|
|
|
auto var = new VarHandle(vars.size() - 1, i, og, p);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ENFORCE("Not implemented");
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
|
|
|
|
@ -379,7 +394,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto *communication_dev_ctx = nccl_ctxs_->DevCtx(places_[i]);
|
|
|
|
|
auto *communication_dev_ctx =
|
|
|
|
|
nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
|
|
|
|
|
: platform::DeviceContextPool::Instance().Get(places_[i]);
|
|
|
|
|
#else
|
|
|
|
|
auto *communication_dev_ctx =
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
|
|
|
@ -425,8 +442,13 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &vars = result->vars_[i][og];
|
|
|
|
|
#ifndef PADDLE_WITH_CUDA
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(
|
|
|
|
|
p, platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
#endif
|
|
|
|
|