|
|
|
@ -17,10 +17,10 @@
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/reduce_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/rpc_op_handle.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
|
|
|
@ -283,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::SetCommunicationContext(
|
|
|
|
|
OpHandleBase *op_handle, const platform::Place &p) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
@ -306,19 +319,6 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::SetCommunicationContext(
|
|
|
|
|
OpHandleBase *op_handle, const platform::Place &p) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
|
|
|
|
|
const OpDesc &op,
|
|
|
|
|
int dev_id) const {
|
|
|
|
@ -331,9 +331,9 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
|
|
|
|
|
SSAGraph *result, const std::string &og) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->ops_.emplace_back(
|
|
|
|
|
new NCCLAllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
|
|
|
|
|
#else
|
|
|
|
|
result->ops_.emplace_back(new NCCLAllReduceOpHandle(local_scopes_, places_));
|
|
|
|
|
result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
|
|
|
|
|