@@ -78,6 +78,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
   }
 }
 
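+// Heuristic: an op belongs to the distributed-training pipeline when one of
+// its variables carries the ".block" suffix (a sliced-parameter name such as
+// "w.block0", presumably produced by the distribute transpiler) and the same
+// name appears in the matching argument list of the send op.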
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
+                                            OpDesc *send_op) const {
+  if (send_op == nullptr) {
+    return false;
+  }
+
+  auto checker = [&](const std::vector<std::string> opvars,
+                     const std::vector<std::string> sendvars) -> bool {
+    bool is_dist_train_op = false;
+    for (auto &var : opvars) {
+      if (var.find(".block") != std::string::npos &&
+          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+        is_dist_train_op = true;
+        break;
+      }
+    }
+    return is_dist_train_op;
+  };
+
+  if (op.Type() == "split") {
+    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  } else if (op.Type() == "concat") {
+    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+  }
+  return false;
+}
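+// Note that the match is purely name-based: a "split"/"concat" op is only
+// classified as a dist-train op if one of its variable names both contains
+// ".block" and appears verbatim among the send op's arguments.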
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
@@ -89,19 +116,30 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
 
+  // Find the "send" op first, since the "split" ops are placed before it.
+  OpDesc *send_op = nullptr;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send") {
+      send_op = op;
+      break;
+    }
+  }
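+  // Only the first "send" op is recorded; the trainer main program is
+  // assumed to contain at most one, so a single pass suffices.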
+
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
     if (op->Type() == "send") {
       // Append the send op if the program is a distributed trainer main program.
       // It always uses the first device.
       CreateSendOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_op)) {
+      CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
       if (!skip_scale_loss_) {
         CreateScaleLossGradOp(&result);
       }
       is_forwarding = false;
     } else {
-      CreateComputationalOps(&result, *op);
+      CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding) {
         // Currently, we assume that once a gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once. But there are no
@@ -199,8 +237,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
 }
 
 void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
-                                                     const OpDesc &op) const {
-  for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) {
+                                                     const OpDesc &op,
+                                                     size_t num_places) const {
+  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
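+    // scope_idx indexes places_ and local_scopes_ directly, so callers are
+    // expected to pass num_places <= places_.size(): 1 for dist-train ops,
+    // places_.size() for regular ops (see Build above).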
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
     result->ops_.emplace_back(new ComputationOpHandle(op, s, p));