|
|
|
@ -93,7 +93,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
|
|
|
|
|
}
|
|
|
|
|
var_holder.emplace_back(var);
|
|
|
|
|
} else {
|
|
|
|
|
var = var_holder.rbegin()->get();
|
|
|
|
|
var = *var_holder.rbegin();
|
|
|
|
|
}
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
@ -155,7 +155,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
|
|
|
|
|
ir::Node *node,
|
|
|
|
|
size_t place_id) const {
|
|
|
|
|
auto p = places_[place_id];
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
|
|
|
|
@ -498,7 +498,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
|
|
|
|
|
|
|
|
|
|
auto *in =
|
|
|
|
|
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
|
|
|
|
|
result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back();
|
|
|
|
|
op_handle->AddInput(in);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
@ -535,7 +535,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
|
|
|
|
|
for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
|
|
|
|
|
for (auto &p_name : bcast_varnames[dev_id]) {
|
|
|
|
|
auto *in =
|
|
|
|
|
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
|
|
|
|
|
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back();
|
|
|
|
|
op_handle->AddInput(in);
|
|
|
|
|
for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
|
|
|
|
|
auto &p = places_[out_dev_id];
|
|
|
|
@ -571,7 +571,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
|
|
|
|
|
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
@ -579,7 +579,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
|
|
|
|
|
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
op_handle->AddInput(prev_grad);
|
|
|
|
|
|
|
|
|
|
auto var =
|
|
|
|
|
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
|
|
|
|
@ -600,14 +600,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
for (const std::string &d_name : datas) {
|
|
|
|
|
auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
op_handle->AddInput(vars.back().get());
|
|
|
|
|
op_handle->AddInput(vars.back());
|
|
|
|
|
auto var = new VarHandle(
|
|
|
|
|
result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
|
|
|
|
|
vars.size(), i, d_name, p);
|
|
|
|
@ -691,7 +691,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
|
|
|
|
|
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
|
|
|
|
|
local_scopes_, places_));
|
|
|
|
|
#endif
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
@ -699,7 +699,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
|
|
|
|
|
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
op_handle->AddInput(prev_grad);
|
|
|
|
|
}
|
|
|
|
|
auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
|
|
|
|
|
auto var =
|
|
|
|
@ -760,14 +760,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
|
|
|
|
|
for (ir::Node *input : node->inputs) {
|
|
|
|
|
VarHandle *var = nullptr;
|
|
|
|
|
for (int place_offset = 0; place_offset < num_places; ++place_offset) {
|
|
|
|
|
auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
|
|
|
|
|
auto &var_holder = var_holders[input->Name()];
|
|
|
|
|
if (!var_holder.empty()) {
|
|
|
|
|
var = var_holder.rbegin()->get();
|
|
|
|
|
var = *var_holder.rbegin();
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -840,7 +840,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
|
|
|
|
|
// send_barrier, recv, fetch_barrier's inputs are deps var, get them from
|
|
|
|
|
// all places
|
|
|
|
|
auto p = places_[op_dev_id];
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
|
|
|
|
|