|
|
|
@ -36,6 +36,15 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace imperative {
|
|
|
|
|
|
|
|
|
|
struct HashPair {
|
|
|
|
|
template <class T1, class T2>
|
|
|
|
|
size_t operator()(const std::pair<T1, T2> &p) const noexcept {
|
|
|
|
|
auto hash1 = std::hash<T1>{}(p.first);
|
|
|
|
|
auto hash2 = std::hash<T2>{}(p.second);
|
|
|
|
|
return hash1 ^ hash2;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* This function prunes the graph to get the ops between `output_targets`
|
|
|
|
|
* and `input_target_grads`.
|
|
|
|
@ -152,8 +161,10 @@ static void GetGraphInfoBetweenTargets(
|
|
|
|
|
target_vars = *input_target_grads;
|
|
|
|
|
|
|
|
|
|
std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue;
|
|
|
|
|
std::unordered_set<std::pair<OpBase *, OpBase *>, HashPair> op_base_visited;
|
|
|
|
|
for (auto &endpoint_op : endpoint_ops) {
|
|
|
|
|
op_queue.emplace(endpoint_op, nullptr);
|
|
|
|
|
op_base_visited.emplace(endpoint_op, nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
while (!op_queue.empty()) {
|
|
|
|
@ -207,6 +218,7 @@ static void GetGraphInfoBetweenTargets(
|
|
|
|
|
if (pending_op) {
|
|
|
|
|
VLOG(10) << "Pending op of " << op->Type() << " is "
|
|
|
|
|
<< pending_op->Type();
|
|
|
|
|
|
|
|
|
|
pending_ops[op].insert(pending_op);
|
|
|
|
|
++op_deps[pending_op];
|
|
|
|
|
} else {
|
|
|
|
@ -216,7 +228,10 @@ static void GetGraphInfoBetweenTargets(
|
|
|
|
|
auto iter = preceding_ops.find(op);
|
|
|
|
|
if (iter != preceding_ops.end()) {
|
|
|
|
|
for (auto &preceding_op : iter->second) {
|
|
|
|
|
op_queue.emplace(preceding_op, op);
|
|
|
|
|
if (op_base_visited.count(std::make_pair(preceding_op, op)) == 0) {
|
|
|
|
|
op_queue.emplace(preceding_op, op);
|
|
|
|
|
op_base_visited.emplace(preceding_op, op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|