|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
@ -52,13 +53,28 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
|
|
|
|
|
// Note that must assert topology sort is stable
|
|
|
|
|
auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
|
|
|
|
|
for (auto* op_desc : ops) {
|
|
|
|
|
auto outputs = op_desc->Outputs();
|
|
|
|
|
for (auto& o_it : outputs) {
|
|
|
|
|
for (auto& v : o_it.second) { // values
|
|
|
|
|
vars[v] = order;
|
|
|
|
|
try {
|
|
|
|
|
bool is_bk_op =
|
|
|
|
|
static_cast<bool>(boost::get<int>(op_desc->GetAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleAttrName())) &
|
|
|
|
|
static_cast<int>(OpRole::kBackward));
|
|
|
|
|
if (!is_bk_op) continue;
|
|
|
|
|
|
|
|
|
|
auto backward_vars =
|
|
|
|
|
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
|
|
|
|
|
|
|
|
|
|
auto outputs = op_desc->Outputs();
|
|
|
|
|
for (auto& o_it : outputs) {
|
|
|
|
|
for (auto& v : o_it.second) { // values
|
|
|
|
|
vars[v] = order;
|
|
|
|
|
VLOG(1) << "in all_reduce_deps_pass:" << v;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
order++;
|
|
|
|
|
} catch (boost::bad_get e) {
|
|
|
|
|
}
|
|
|
|
|
order++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<OpHandleBase*> dist_ops;
|
|
|
|
|