|
|
|
@ -12,6 +12,7 @@
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>
|
|
|
|
|
|
|
|
|
@ -23,6 +24,25 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
// Breadth-first search, starting at `var_in`, for the nearest downstream
// ComputationOpHandle whose place equals `var_in->place_`.
//
// The walk alternates var -> PendingOps() -> Outputs(). Any op that is not a
// ComputationOpHandle, or that is one but sits on a different place, is
// stepped over and the search continues through its output variables.
//
// Returns the first matching op in BFS order, or nullptr when no such op is
// reachable. `var_in` is dereferenced unconditionally and must not be null.
static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
  std::queue<VarHandleBase *> pending_vars;
  // Variables that have already been enqueued. Without this, a variable that
  // is reachable along several paths would be expanded once per path, which
  // blows up on diamond-shaped graphs and never terminates on a cycle.
  // Skipping repeats cannot change the result: re-expanding a variable only
  // re-examines ops that were already rejected.
  std::unordered_set<VarHandleBase *> visited;

  pending_vars.push(var_in);
  visited.insert(var_in);

  while (!pending_vars.empty()) {
    auto *var = pending_vars.front();
    pending_vars.pop();

    for (auto *op : var->PendingOps()) {
      auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
      if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
        return compute_op;
      }
      // Not a match: keep searching through everything this op produces.
      for (auto *out_var : op->Outputs()) {
        if (visited.insert(out_var).second) {
          pending_vars.push(out_var);
        }
      }
    }
  }
  return nullptr;
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
|
|
|
|
|
std::unique_ptr<ir::Graph> graph) const {
|
|
|
|
|
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
|
|
|
|
@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
|
|
|
|
|
// Step 2: Find all variables in non-computation ops which refers to variables
|
|
|
|
|
// in computation ops
|
|
|
|
|
std::unordered_set<std::string> names;
|
|
|
|
|
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>>
|
|
|
|
|
compute_ref_cnt_map;
|
|
|
|
|
|
|
|
|
|
auto get_ref_cnts_from_compute_op = [&](
|
|
|
|
|
const std::unique_ptr<OpHandleBase> &op,
|
|
|
|
|
const std::vector<VarHandleBase *> &vars) {
|
|
|
|
@ -54,15 +77,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
|
|
|
|
|
VarDesc *var_desc = var_handle->Node()->Var();
|
|
|
|
|
auto var_name = var_handle->Node()->Name();
|
|
|
|
|
|
|
|
|
|
// This is wierd but there is really some variables without var_desc
|
|
|
|
|
// This is weird, but there really are some variables without var_desc
|
|
|
|
|
// in computation_op
|
|
|
|
|
if (var_desc == nullptr) {
|
|
|
|
|
if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
if (var_desc->Persistable() ||
|
|
|
|
|
var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR)
|
|
|
|
|
if (var_desc->Persistable()) continue;
|
|
|
|
|
auto var_type = var_desc->Proto()->type().type();
|
|
|
|
|
if (var_type != proto::VarType::LOD_TENSOR &&
|
|
|
|
|
var_type != proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// compute op only runs in one device
|
|
|
|
@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
|
|
|
|
|
if (ref_cnts.count(place.device) &&
|
|
|
|
|
ref_cnts[place.device]->count(var_name)) {
|
|
|
|
|
++(*ref_cnts[place.device])[var_name];
|
|
|
|
|
|
|
|
|
|
auto *next_compute_op = FindNextComputationOpHandle(var_handle);
|
|
|
|
|
if (next_compute_op != nullptr) {
|
|
|
|
|
if (compute_ref_cnt_map.count(next_compute_op)) {
|
|
|
|
|
compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
|
|
|
|
|
VLOG(5) << "Add reference count of " << var_name << " to Operator "
|
|
|
|
|
<< next_compute_op->Name();
|
|
|
|
|
} else {
|
|
|
|
|
// Create new reference_count_op_handle
|
|
|
|
|
ir::Node *ref_cnt_node = graph->CreateEmptyNode(
|
|
|
|
|
"reference_count", ir::Node::Type::kOperation);
|
|
|
|
|
auto *ref_cnt_handle = new ReferenceCountOpHandle(
|
|
|
|
|
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
|
|
|
|
|
gcs[place.device].get(), cur_ref_cnts[place.device].get());
|
|
|
|
|
if (next_compute_op->Outputs().empty()) {
|
|
|
|
|
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
|
|
|
|
|
next_compute_op->AddOutput(dep_var);
|
|
|
|
|
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
|
|
|
|
|
}
|
|
|
|
|
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
|
|
|
|
|
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
|
|
|
|
|
compute_ref_cnt_map;
|
|
|
|
|
auto &all_ops = graph->Get<GraphOps>(kGraphOps);
|
|
|
|
|
for (auto &op : all_ops) {
|
|
|
|
|
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
|
|
|
|
@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
|
|
|
|
|
auto *ref_cnt_handle = new ReferenceCountOpHandle(
|
|
|
|
|
ref_cnt_node, compute_op->GetScope(), place, in_var_names,
|
|
|
|
|
gcs[place.device].get(), cur_ref_cnts[place.device].get());
|
|
|
|
|
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
|
|
|
|
|
compute_op->AddOutput(dep_var);
|
|
|
|
|
ref_cnt_handle->AddInput(dep_var);
|
|
|
|
|
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
|
|
|
|
|
compute_ref_cnt_map[compute_op] = ref_cnt_handle;
|
|
|
|
|
if (compute_op->Outputs().empty()) {
|
|
|
|
|
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
|
|
|
|
|
compute_op->AddOutput(dep_var);
|
|
|
|
|
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
|
|
|
|
|
}
|
|
|
|
|
ref_cnt_handle->AddInput(compute_op->Outputs().front());
|
|
|
|
|
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &op : all_ops) {
|
|
|
|
@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
|
|
|
|
|
new_all_ops.emplace_back(std::move(op));
|
|
|
|
|
auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
|
|
|
|
|
if (it != compute_ref_cnt_map.end()) {
|
|
|
|
|
new_all_ops.emplace_back(it->second);
|
|
|
|
|
// Add LeafNode to ReferenceCountOpHandle
|
|
|
|
|
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
|
|
|
|
|
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
|
|
|
|
|
it->second->AddOutput(dummy_leaf);
|
|
|
|
|
new_all_ops.emplace_back(std::move(it->second));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|