You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/details/reference_count_pass.cc

228 lines
7.3 KiB

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <string>
6 years ago
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
6 years ago
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
6 years ago
struct OpConnectionDetector {
public:
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
explicit OpConnectionDetector(const std::vector<OpHandleBase *> &all_ops)
: graph_(all_ops) {}
template <typename OpSet>
std::unordered_set<typename OpSet::key_type> MaxNoDepOps(
const OpSet &op_set) {
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be or derived of OpHandleBase");
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
std::unordered_set<KeyType> ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
for (size_t i = 0; i < rels.size(); ++i) {
if (std::all_of(rels[i].begin(), rels[i].end(), not_before)) {
ret.insert(static_cast<KeyType>(ops[i]));
}
}
return ret;
}
private:
std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> ops) {
std::unordered_map<OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i;
}
PADDLE_ENFORCE(op_to_idx.size() == ops.size(), "Duplicate ops");
std::vector<std::vector<RelationShip>> ret(ops.size());
for (auto &e : ret) {
e.assign(ops.size(), kSame);
}
size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
if (ret[i][j] != kSame) {
ret[i][j] = kBefore;
ret[j][i] = kAfter;
found_num += 2;
if (found_num == total_num) {
return false;
}
}
}
return true;
};
for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break;
}
}
for (size_t i = 0; i < ops.size(); ++i) {
for (size_t j = i + 1; j < ops.size(); ++j) {
if (ret[i][j] != kSame) continue;
ret[i][j] = kNoDeps;
ret[j][i] = kNoDeps;
}
}
return ret;
}
const OpGraphView graph_;
};
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
q.push(op);
do {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
for (auto *pending_op : out_var->PendingOps()) {
if (visited.count(pending_op)) continue;
visited.insert(pending_op);
}
}
} while (!q.empty());
return nullptr;
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto &vars = graph->Get<GraphVars>(kGraphVars);
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
ref_cnts = std::vector<ReferenceCountMap>(vars.size());
6 years ago
OpConnectionDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) {
6 years ago
if (name_var_pair.second.empty()) {
continue;
}
const std::string &var_name = name_var_pair.first;
auto *last_ver_var = name_var_pair.second.back();
VarDesc *var_desc = nullptr;
std::find_if(name_var_pair.second.rbegin(), name_var_pair.second.rend(),
[&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
if (var_desc == nullptr || var_desc->Persistable()) {
continue;
6 years ago
}
auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS &&
var_type != proto::VarType::LOD_TENSOR_ARRAY) {
6 years ago
continue;
}
std::unordered_set<ComputationOpHandle *> last_live_op;
6 years ago
auto add_last_live_op = [&](OpHandleBase *op) -> bool {
auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i);
if (compute_op) {
last_live_op.insert(compute_op);
6 years ago
return true;
} else {
return false;
}
};
6 years ago
bool can_delete = false;
auto &pending_ops = last_ver_var->PendingOps();
if (pending_ops.empty()) {
auto *generated_op = last_ver_var->GeneratedOp();
6 years ago
if (generated_op && add_last_live_op(generated_op)) {
can_delete = true;
}
} else {
6 years ago
can_delete = true;
for (auto *pending_op : pending_ops) {
6 years ago
if (!add_last_live_op(pending_op)) {
can_delete = false;
break;
}
}
}
6 years ago
if (can_delete) {
size_t original_size = last_live_op.size();
last_live_op = detector.MaxNoDepOps(last_live_op);
if (last_live_op.size() != original_size) {
VLOG(10) << "Shrink last living op number of " << var_name << " from "
<< original_size << " to " << last_live_op.size();
}
ref_cnts[i].emplace(var_name, last_live_op.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
}
}
}
6 years ago
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);