parent
dcece75b57
commit
2f47562df8
@ -1,6 +1,11 @@
|
|||||||
|
# ddim lib
|
||||||
cc_library(ddim SRCS ddim.cc)
|
cc_library(ddim SRCS ddim.cc)
|
||||||
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
|
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
|
||||||
|
|
||||||
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
|
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
|
||||||
|
|
||||||
cc_test(variable_test SRCS variable_test.cc)
|
cc_test(variable_test SRCS variable_test.cc)
|
||||||
|
|
||||||
|
# scope lib
|
||||||
|
cc_library(scope SRCS scope.cc)
|
||||||
|
cc_test(scope_test SRCS scope_test.cc DEPS scope)
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
Error Scope::CreateVariable(const std::string &name) {
|
||||||
|
if (name == "") {
|
||||||
|
return Error("Variable name should not be empty");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (HaveVariable(name)) {
|
||||||
|
return AlreadyCreated;
|
||||||
|
}
|
||||||
|
vars_[name] = std::unique_ptr<Variable>(new Variable());
|
||||||
|
return Error();
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable* Scope::GetVarLocally(const std::string& name) const {
|
||||||
|
if (vars_.count(name)) {
|
||||||
|
return vars_.at(name).get();
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable* Scope::GetVariable(const std::string &name) const {
|
||||||
|
Variable* var = GetVarLocally(name);
|
||||||
|
if (var != nullptr) {
|
||||||
|
return var;
|
||||||
|
} else if (parent_ != nullptr) {
|
||||||
|
return parent_->GetVariable(name);
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable* Scope::GetOrCreateVariable(const std::string &name) {
|
||||||
|
Variable* var;
|
||||||
|
var = GetVariable(name);
|
||||||
|
if (var == nullptr) {
|
||||||
|
auto err = CreateVariable(name);
|
||||||
|
if (!err.isOK()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return GetVariable(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Scope::HaveVariable(const std::string &name) {
|
||||||
|
return vars_.count(name) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
@ -0,0 +1,51 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "paddle/framework/variable.h"
|
||||||
|
#include "paddle/utils/Error.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
const static Error AlreadyCreated("Variable has already been created");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scope is an association of a name to Variable. All variables belong to `Scope`.
|
||||||
|
* You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can
|
||||||
|
* run in different scopes and update different variable in the scope.
|
||||||
|
*/
|
||||||
|
class Scope {
|
||||||
|
public:
|
||||||
|
Scope() {}
|
||||||
|
|
||||||
|
explicit Scope(const std::shared_ptr<Scope> &scope):
|
||||||
|
parent_(scope) {}
|
||||||
|
|
||||||
|
~Scope() {}
|
||||||
|
|
||||||
|
// Create Variable in this Scope. Return error if Variable already been
|
||||||
|
// created.
|
||||||
|
Error __must_check CreateVariable(const std::string& name);
|
||||||
|
|
||||||
|
// Get Variable from this Scope, this function will recursive find Variable
|
||||||
|
// from it's parent scope.
|
||||||
|
// Return nullptr if not found.
|
||||||
|
Variable* GetVariable(const std::string& name) const;
|
||||||
|
|
||||||
|
// find and return Variables in the scope it self.
|
||||||
|
Variable* GetVarLocally(const std::string& name) const;
|
||||||
|
|
||||||
|
// Get a Variable from Scope, if the Variable is not exist then create it.
|
||||||
|
// User should call this function most of time.
|
||||||
|
Variable* GetOrCreateVariable(const std::string& name);
|
||||||
|
|
||||||
|
bool HaveVariable(const std::string& name);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
|
||||||
|
std::shared_ptr<Scope> parent_ {nullptr};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,47 @@
|
|||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
TEST(Scope, Create) {
|
||||||
|
using paddle::framework::Scope;
|
||||||
|
using paddle::Error;
|
||||||
|
using paddle::framework::Variable;
|
||||||
|
using paddle::framework::AlreadyCreated;
|
||||||
|
|
||||||
|
Scope* scope = new Scope();
|
||||||
|
|
||||||
|
Error err = scope->CreateVariable("");
|
||||||
|
EXPECT_FALSE(err.isOK());
|
||||||
|
|
||||||
|
Variable* var1 = scope->GetVariable("a");
|
||||||
|
EXPECT_EQ(var1, nullptr);
|
||||||
|
|
||||||
|
Error err1 = scope->CreateVariable("a");
|
||||||
|
EXPECT_TRUE(err1.isOK());
|
||||||
|
|
||||||
|
Error err2 = scope->CreateVariable("a");
|
||||||
|
EXPECT_EQ(err2, AlreadyCreated);
|
||||||
|
|
||||||
|
Variable* var2 = scope->GetVariable("a");
|
||||||
|
EXPECT_NE(var2, nullptr);
|
||||||
|
|
||||||
|
Variable* var3 = scope->GetOrCreateVariable("b");
|
||||||
|
EXPECT_NE(var3, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Scope, Parent) {
|
||||||
|
using paddle::framework::Scope;
|
||||||
|
using paddle::framework::Variable;
|
||||||
|
using paddle::Error;
|
||||||
|
|
||||||
|
const auto parent_scope_ptr = std::shared_ptr<Scope>(new Scope());
|
||||||
|
Scope* scope = new Scope(parent_scope_ptr);
|
||||||
|
|
||||||
|
Error err = parent_scope_ptr->CreateVariable("a");
|
||||||
|
EXPECT_TRUE(err.isOK());
|
||||||
|
|
||||||
|
Variable* var1 = scope->GetVarLocally("a");
|
||||||
|
EXPECT_EQ(var1, nullptr);
|
||||||
|
|
||||||
|
Variable* var2 = scope->GetVariable("a");
|
||||||
|
EXPECT_NE(var2, nullptr);
|
||||||
|
}
|
Loading…
Reference in new issue