|
|
|
@ -207,14 +207,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
is_forwarding = false;
|
|
|
|
|
} else {
|
|
|
|
|
int op_dev_id = GetOpDeviceID(*op);
|
|
|
|
|
if (op_dev_id == -1) { // var on all device
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
} else {
|
|
|
|
|
if (op_dev_id != -1) { // This op only runs on one specific device.
|
|
|
|
|
CreateComputationalOp(&result, *op, op_dev_id);
|
|
|
|
|
for (auto &var_name : op->OutputArgumentNames()) {
|
|
|
|
|
var_name_on_devices_.emplace(var_name, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// This op runs on all devices, and its output may have parameter's
|
|
|
|
|
// gradients.
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
|
|
|
|
|
if (!is_forwarding && places_.size() > 1) {
|
|
|
|
|
// Currently, we assume that once gradient is generated, it can be
|
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
@ -259,6 +261,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Insert BCast Ops
|
|
|
|
|
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
|
|
|
|
|