small refine (#11460)

port
chengduo 7 years ago committed by GitHub
parent ab0c2e1dab
commit bb29800aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -207,14 +207,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(var_name, op_dev_id);
}
}
} else {
// This op runs on all devices, and its output may have parameter's
// gradients.
CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
@ -259,6 +261,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
}
}
// Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {

Loading…
Cancel
Save