|
|
|
@ -326,7 +326,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
|
|
|
|
|
ir::Graph &result = *graph;
|
|
|
|
|
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) {
|
|
|
|
|
if (node->IsVar() && node->Var()) {
|
|
|
|
|
all_vars_.emplace(node->Name(), node->Var());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -583,18 +583,6 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
|
|
|
|
|
const std::string &og,
|
|
|
|
|
std::unordered_set<std::string> *og_has_been_broadcast) const {
|
|
|
|
|
bool is_pg_once =
|
|
|
|
|
grad_names_.count(og) != 0 && og_has_been_broadcast->count(og) == 0;
|
|
|
|
|
if (is_pg_once) {
|
|
|
|
|
// Insert NCCL AllReduce Op
|
|
|
|
|
og_has_been_broadcast->insert(og);
|
|
|
|
|
}
|
|
|
|
|
return is_pg_once;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
|
|
|
|
@ -688,20 +676,6 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Find the first occurence of `prev_op_name` and make current `op` depend
|
|
|
|
|
// on it.
|
|
|
|
|
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
|
|
|
|
|
const std::string &prev_op_name) const {
|
|
|
|
|
for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
|
|
|
|
|
if (prev_op->Name() == prev_op_name) {
|
|
|
|
|
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
|
|
|
|
|
prev_op->AddOutput(dep_var);
|
|
|
|
|
result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
|
|
|
|
|
op->AddInput(dep_var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|