|
|
|
@ -152,8 +152,10 @@ def fetch_var(name, scope=None, return_numpy=True):
|
|
|
|
|
"""
|
|
|
|
|
Fetch the value of the variable with the given name from the given scope
|
|
|
|
|
Args:
|
|
|
|
|
name(str): name of the variable
|
|
|
|
|
scope(core.Scope|None): scope object.
|
|
|
|
|
name(str): name of the variable. Typically, only persistable variables
|
|
|
|
|
can be found in the scope used for running the program.
|
|
|
|
|
scope(core.Scope|None): scope object. It should be the scope where
|
|
|
|
|
you pass to Executor.run() when running your program.
|
|
|
|
|
If None, global_scope() will be used.
|
|
|
|
|
return_numpy(bool): whether convert the tensor to numpy.ndarray
|
|
|
|
|
Returns:
|
|
|
|
@ -165,7 +167,10 @@ def fetch_var(name, scope=None, return_numpy=True):
|
|
|
|
|
assert isinstance(scope, core.Scope)
|
|
|
|
|
|
|
|
|
|
var = global_scope().find_var(name)
|
|
|
|
|
assert var is not None, "Cannot find '%s' in scope." % name
|
|
|
|
|
assert var is not None, (
|
|
|
|
|
"Cannot find " + name + " in scope. Perhaps you need to make the"
|
|
|
|
|
" variable persistable by using var.persistable = True in your"
|
|
|
|
|
" program.")
|
|
|
|
|
tensor = var.get_tensor()
|
|
|
|
|
if return_numpy:
|
|
|
|
|
tensor = as_numpy(tensor)
|
|
|
|
|