|
|
|
@ -147,15 +147,52 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
|
|
|
|
|
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
|
|
|
|
|
std::vector<std::string> &v) {
|
|
|
|
|
auto in_names = (*op)->InputArgumentNames();
|
|
|
|
|
v.insert(v.end(), in_names.begin(), in_names.end());
|
|
|
|
|
auto out_names = (*op)->OutputArgumentNames();
|
|
|
|
|
v.insert(v.end(), out_names.begin(), out_names.end());
|
|
|
|
|
std::sort(v.begin(), v.end());
|
|
|
|
|
auto last = std::unique(v.begin(), v.end());
|
|
|
|
|
v.erase(last, v.end());
|
|
|
|
|
};
|
|
|
|
|
need_update_ = true;
|
|
|
|
|
for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
|
|
|
|
|
auto names = (*it)->InputArgumentNames();
|
|
|
|
|
for (auto n : names) {
|
|
|
|
|
// TODO(typhoonzero): delete vars if no other op use it.
|
|
|
|
|
VLOG(3) << "deleting var " << n;
|
|
|
|
|
|
|
|
|
|
for (size_t i = s; i < e; i++) {
|
|
|
|
|
// since remove op one by one, every time remove the first op.
|
|
|
|
|
auto op = ops_.begin() + s;
|
|
|
|
|
|
|
|
|
|
// collect input and output variables from current delete op
|
|
|
|
|
std::vector<std::string> cur_vars;
|
|
|
|
|
get_vars(op, cur_vars);
|
|
|
|
|
|
|
|
|
|
// remove current op
|
|
|
|
|
ops_.erase(ops_.begin() + s);
|
|
|
|
|
|
|
|
|
|
// collect input and output variables from other ops
|
|
|
|
|
std::vector<std::string> other_vars;
|
|
|
|
|
for (auto it = ops_.begin(); it != ops_.end(); it++) {
|
|
|
|
|
get_vars(it, other_vars);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// variables should be deleted
|
|
|
|
|
std::vector<std::string> delete_vars;
|
|
|
|
|
// delete_vars = cur_vars - cur_vars ^ other_input_vars
|
|
|
|
|
std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
|
|
|
|
|
other_vars.end(),
|
|
|
|
|
std::inserter(delete_vars, delete_vars.end()));
|
|
|
|
|
// remove variables
|
|
|
|
|
for (size_t i = 0; i < delete_vars.size(); i++) {
|
|
|
|
|
auto name = delete_vars[i];
|
|
|
|
|
auto it = vars_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != vars_.end(),
|
|
|
|
|
"%s is not in variable list, it should not be deleted",
|
|
|
|
|
name);
|
|
|
|
|
vars_.erase(it);
|
|
|
|
|
VLOG(3) << "deleting variable " << name;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ops_.erase(ops_.begin() + s, ops_.begin() + e);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<OpDesc *> BlockDesc::AllOps() const {
|
|
|
|
|