|
|
|
@ -29,7 +29,7 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
struct OpConnectionDetector {
|
|
|
|
|
class OpConnectionDetector {
|
|
|
|
|
public:
|
|
|
|
|
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
|
|
|
|
|
|
|
|
|
@ -37,8 +37,8 @@ struct OpConnectionDetector {
|
|
|
|
|
: graph_(all_ops) {}
|
|
|
|
|
|
|
|
|
|
template <typename OpSet>
|
|
|
|
|
std::unordered_set<typename OpSet::key_type> MaxNoDepOps(
|
|
|
|
|
const OpSet &op_set) {
|
|
|
|
|
OpSet MaxNoDepOps(const OpSet &op_set) {
|
|
|
|
|
if (op_set.size() <= 1) return op_set;
|
|
|
|
|
using KeyType = typename OpSet::key_type;
|
|
|
|
|
static_assert(
|
|
|
|
|
std::is_base_of<OpHandleBase,
|
|
|
|
@ -46,7 +46,7 @@ struct OpConnectionDetector {
|
|
|
|
|
"Key type of OpSet must be or derived of OpHandleBase");
|
|
|
|
|
|
|
|
|
|
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
|
|
|
|
|
std::unordered_set<KeyType> ret;
|
|
|
|
|
OpSet ret;
|
|
|
|
|
auto rels = GetRelations(ops);
|
|
|
|
|
auto not_before = [](RelationShip r) { return r != kBefore; };
|
|
|
|
|
for (size_t i = 0; i < rels.size(); ++i) {
|
|
|
|
@ -79,7 +79,7 @@ struct OpConnectionDetector {
|
|
|
|
|
auto it = op_to_idx.find(op);
|
|
|
|
|
if (it != op_to_idx.end()) {
|
|
|
|
|
size_t j = it->second;
|
|
|
|
|
if (ret[i][j] != kSame) {
|
|
|
|
|
if (i != j && ret[i][j] == kSame) {
|
|
|
|
|
ret[i][j] = kBefore;
|
|
|
|
|
ret[j][i] = kAfter;
|
|
|
|
|
found_num += 2;
|
|
|
|
@ -208,6 +208,10 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
|
|
|
|
|
VLOG(10) << "Shrink last living op number of " << var_name << " from "
|
|
|
|
|
<< original_size << " to " << last_live_op.size();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(!last_live_op.empty(),
|
|
|
|
|
"Last living ops of %s cannot be empty", var_name);
|
|
|
|
|
|
|
|
|
|
ref_cnts[i].emplace(var_name, last_live_op.size());
|
|
|
|
|
last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
|
|
|
|
|
}
|
|
|
|
|