Polish `Scope::LocalVarNames`

Cannot get var name recursive since they could be same.
del_some_in_makelist
Yang Yu 7 years ago
parent f839154542
commit ef188371a1

@ -74,17 +74,9 @@ void Scope::DropKids() {
kids_.clear();
}
std::vector<std::string> Scope::GetAllNames(bool recursive) const {
std::vector<std::string> known_vars(vars_.size());
if (recursive) {
for (auto& kid : kids_) {
auto kid_vars = kid->GetAllNames();
for (auto& p : kid_vars) {
known_vars.emplace_back(p);
}
}
}
std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.first);
}

@ -66,7 +66,7 @@ class Scope {
void DropKids();
// enumerate all the variables current contains.
std::vector<std::string> GetAllNames(bool recursive = false) const;
std::vector<std::string> LocalVarNames() const;
// Rename variable to a new name
void Rename(const std::string& origin_name,

@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) {
Variable* v = s.Var("a");
EXPECT_EQ(&s, s.FindScope(v));
std::vector<std::string> ans = s.GetAllNames();
std::vector<std::string> ans = s.LocalVarNames();
std::string str;
for (auto& var : ans) {
str += var;

@ -491,7 +491,7 @@ class RecurrentGradOp : public RecurrentBase {
std::unordered_set<std::string> LocalVarNames(
const framework::Scope &scope) const {
return this->List2Set(scope.GetAllNames(false));
return this->List2Set(scope.LocalVarNames());
}
static std::vector<std::string> GradVarLists(
const std::vector<std::string> &var_names) {

Loading…
Cancel
Save