Merge pull request #7030 from reyoung/feature/fix_get_all_names

Polish `Scope::LocalVarNames`
del_some_in_makelist
Yu Yang 7 years ago committed by GitHub
commit 19f2475af1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -74,17 +74,9 @@ void Scope::DropKids() {
kids_.clear(); kids_.clear();
} }
std::vector<std::string> Scope::GetAllNames(bool recursive) const { std::vector<std::string> Scope::LocalVarNames() const {
std::vector<std::string> known_vars(vars_.size()); std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
if (recursive) {
for (auto& kid : kids_) {
auto kid_vars = kid->GetAllNames();
for (auto& p : kid_vars) {
known_vars.emplace_back(p);
}
}
}
for (auto& p : vars_) { for (auto& p : vars_) {
known_vars.emplace_back(p.first); known_vars.emplace_back(p.first);
} }

@ -66,7 +66,7 @@ class Scope {
void DropKids(); void DropKids();
// enumerate all the variables current contains. // enumerate all the variables current contains.
std::vector<std::string> GetAllNames(bool recursive = false) const; std::vector<std::string> LocalVarNames() const;
// Rename variable to a new name // Rename variable to a new name
void Rename(const std::string& origin_name, void Rename(const std::string& origin_name,

@ -61,7 +61,7 @@ TEST(Scope, GetAllNames) {
Variable* v = s.Var("a"); Variable* v = s.Var("a");
EXPECT_EQ(&s, s.FindScope(v)); EXPECT_EQ(&s, s.FindScope(v));
std::vector<std::string> ans = s.GetAllNames(); std::vector<std::string> ans = s.LocalVarNames();
std::string str; std::string str;
for (auto& var : ans) { for (auto& var : ans) {
str += var; str += var;

@ -491,7 +491,7 @@ class RecurrentGradOp : public RecurrentBase {
std::unordered_set<std::string> LocalVarNames( std::unordered_set<std::string> LocalVarNames(
const framework::Scope &scope) const { const framework::Scope &scope) const {
return this->List2Set(scope.GetAllNames(false)); return this->List2Set(scope.LocalVarNames());
} }
static std::vector<std::string> GradVarLists( static std::vector<std::string> GradVarLists(
const std::vector<std::string> &var_names) { const std::vector<std::string> &var_names) {

Loading…
Cancel
Save