`Paddle` manages Scope as programming language's scope. It just a thread-local stack of Scope. Top of that stack is current scope, the bottom of that stack is all scopes' parent. Invoking `create_var/get_var` can `create/get` variable in current scope. Invoking `enter_local_scope/leave_local_scope` can create or destroy local scope. A `scoped_function` will take a `function` as input. That function will be invoked in a new local scope.cblas_new
parent
267f9a2cdf
commit
d027f47d7d
@ -0,0 +1,83 @@
|
||||
"""
|
||||
Default scope function.
|
||||
|
||||
`Paddle` manages Scope as programming language's scope. It just a
|
||||
thread-local stack of Scope. Top of that stack is current scope, the bottom
|
||||
of that stack is all scopes' parent.
|
||||
|
||||
Invoking `create_var/get_var` can `create/get` variable in current scope.
|
||||
Invoking `enter_local_scope/leave_local_scope` can create or destroy local
|
||||
scope.
|
||||
|
||||
A `scoped_function` will take a `function` as input. That function will be
|
||||
invoked in a new local scope.
|
||||
"""
|
||||
|
||||
import paddle.v2.framework.core
|
||||
import threading
|
||||
|
||||
__tl_scope__ = threading.local()
|
||||
|
||||
__all__ = [
|
||||
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var',
|
||||
'get_var', 'scoped_function'
|
||||
]
|
||||
|
||||
|
||||
def get_cur_scope():
|
||||
"""
|
||||
Get current scope.
|
||||
:rtype: paddle.v2.framework.core.Scope
|
||||
"""
|
||||
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None)
|
||||
if cur_scope_stack is None:
|
||||
__tl_scope__.cur_scope = list()
|
||||
if len(__tl_scope__.cur_scope) == 0:
|
||||
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None))
|
||||
return __tl_scope__.cur_scope[-1]
|
||||
|
||||
|
||||
def enter_local_scope():
|
||||
"""
|
||||
Enter a new local scope
|
||||
"""
|
||||
cur_scope = get_cur_scope()
|
||||
new_scope = paddle.v2.framework.core.Scope(cur_scope)
|
||||
__tl_scope__.cur_scope.append(new_scope)
|
||||
|
||||
|
||||
def leave_local_scope():
|
||||
"""
|
||||
Leave local scope
|
||||
"""
|
||||
__tl_scope__.cur_scope.pop()
|
||||
|
||||
|
||||
def create_var(name):
|
||||
"""
|
||||
create variable in current scope.
|
||||
"""
|
||||
return get_cur_scope().create_var(name)
|
||||
|
||||
|
||||
def get_var(name):
|
||||
"""
|
||||
get variable in current scope.
|
||||
"""
|
||||
return get_cur_scope().get_var(name)
|
||||
|
||||
|
||||
def scoped_function(func):
|
||||
"""
|
||||
invoke `func` in new scope.
|
||||
|
||||
:param func: a callable function that will be run in new scope.
|
||||
:type func: callable
|
||||
"""
|
||||
enter_local_scope()
|
||||
try:
|
||||
func()
|
||||
except:
|
||||
raise
|
||||
finally:
|
||||
leave_local_scope()
|
@ -1 +1,2 @@
|
||||
add_python_test(test_framework test_protobuf.py test_scope.py)
|
||||
add_python_test(test_framework test_protobuf.py test_scope.py
|
||||
test_default_scope_funcs.py)
|
||||
|
@ -0,0 +1,33 @@
|
||||
from paddle.v2.framework.default_scope_funcs import *
|
||||
import unittest
|
||||
|
||||
|
||||
class TestDefaultScopeFuncs(unittest.TestCase):
|
||||
def test_cur_scope(self):
|
||||
self.assertIsNotNone(get_cur_scope())
|
||||
|
||||
def test_none_variable(self):
|
||||
self.assertIsNone(get_var("test"))
|
||||
|
||||
def test_create_var_get_var(self):
|
||||
var_a = create_var("var_a")
|
||||
self.assertIsNotNone(var_a)
|
||||
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
|
||||
enter_local_scope()
|
||||
self.assertIsNotNone(get_cur_scope().get_var('var_a'))
|
||||
leave_local_scope()
|
||||
|
||||
def test_var_get_int(self):
|
||||
def __new_scope__():
|
||||
i = create_var("var_i")
|
||||
self.assertFalse(i.is_int())
|
||||
i.set_int(10)
|
||||
self.assertTrue(i.is_int())
|
||||
self.assertEqual(10, i.get_int())
|
||||
|
||||
for _ in xrange(10):
|
||||
scoped_function(__new_scope__)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue