|
|
|
@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
|
|
|
|
|
ir::Graph *result, const std::string &loss_grad_name) const {
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto *communication_dev_ctx =
|
|
|
|
|
nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
|
|
|
|
|
: platform::DeviceContextPool::Instance().Get(places_[i]);
|
|
|
|
|
#else
|
|
|
|
|
auto *communication_dev_ctx =
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
|
|
|
|
#endif
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
|
|
|
|
|
auto *op_handle = new ScaleLossGradOpHandle(
|
|
|
|
|
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_.size(), local_scopes_[i], places_[i],
|
|
|
|
|
communication_dev_ctx);
|
|
|
|
|
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
|
|
|
|
|
|
|
|
|
|
// FIXME: Currently ScaleLossGradOp only use device_count as scale
|
|
|
|
|