|
|
|
@ -17,7 +17,9 @@ import contextlib
|
|
|
|
|
from framework import Program, default_main_program
|
|
|
|
|
from . import core
|
|
|
|
|
|
|
|
|
|
__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope']
|
|
|
|
|
__all__ = [
|
|
|
|
|
'Executor', 'global_scope', 'scope_guard', 'switch_scope', 'fetch_var'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
g_scope = core.Scope()
|
|
|
|
|
|
|
|
|
@ -80,12 +82,12 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
|
|
|
|
|
Args:
|
|
|
|
|
block: a block instance (typically global block of a program)
|
|
|
|
|
feed_targets: a dictionary of {feed_target_name: feed_target_data}
|
|
|
|
|
feed_holder_name: the name of the variable that holds the data of
|
|
|
|
|
all feed targets. The type of this feed_holder variable is
|
|
|
|
|
feed_holder_name: the name of the variable that holds the data of
|
|
|
|
|
all feed targets. The type of this feed_holder variable is
|
|
|
|
|
FEED_MINIBATCH, which is essentially vector<LoDTensor>.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A boolean value that indicates whether a block has feed operators
|
|
|
|
|
A boolean value that indicates whether a block has feed operators
|
|
|
|
|
that match the info contained in feed_targets and feed_holder_name.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -108,7 +110,7 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
|
|
|
|
|
|
|
|
|
|
def has_fetch_operators(block, fetch_targets, fetch_holder_name):
|
|
|
|
|
""" Check whether the block already has fetch operators.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Return false if the block does not have any fetch operators.
|
|
|
|
|
If some fetch operators have been appended to the block, check that
|
|
|
|
|
the info contained in these fetch operators matches the fetch_targets
|
|
|
|
@ -118,13 +120,13 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
|
|
|
|
|
Args:
|
|
|
|
|
block: a block instance (typically global block of a program)
|
|
|
|
|
fetch_targets: a dictionary of {fetch_target_name: fetch_target_data}
|
|
|
|
|
fetch_holder_name: the name of the variable that holds the data of
|
|
|
|
|
all fetch targets. The type of this fetch_holder variable is
|
|
|
|
|
FETCH_LIST, which is essentially vector<LoDTensor>.
|
|
|
|
|
fetch_holder_name: the name of the variable that holds the data of
|
|
|
|
|
all fetch targets. The type of this fetch_holder variable is
|
|
|
|
|
FETCH_LIST, which is essentially vector<LoDTensor>.
|
|
|
|
|
|
|
|
|
|
Return:
|
|
|
|
|
A boolean value that indicates whether a block has fetch operators
|
|
|
|
|
that match the info contained in fetch_targets and fetch_holder_name.
|
|
|
|
|
Return:
|
|
|
|
|
A boolean value that indicates whether a block has fetch operators
|
|
|
|
|
that match the info contained in fetch_targets and fetch_holder_name.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
fetch_count = 0
|
|
|
|
@ -146,6 +148,30 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
|
|
|
|
|
return fetch_count > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_var(name, scope=None, return_numpy=True):
|
|
|
|
|
"""
|
|
|
|
|
Fetch the value of the variable with the given name from the given scope
|
|
|
|
|
Args:
|
|
|
|
|
name(str): name of the variable
|
|
|
|
|
scope(core.Scope|None): scope object.
|
|
|
|
|
If None, global_scope() will be used.
|
|
|
|
|
return_numpy(bool): whether convert the tensor to numpy.ndarray
|
|
|
|
|
Returns:
|
|
|
|
|
LodTensor|numpy.ndarray
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(name, str)
|
|
|
|
|
if scope is None:
|
|
|
|
|
scope = global_scope()
|
|
|
|
|
assert isinstance(scope, core.Scope)
|
|
|
|
|
|
|
|
|
|
var = global_scope().find_var(name)
|
|
|
|
|
assert var is not None, "Cannot find '%s' in scope." % name
|
|
|
|
|
tensor = var.get_tensor()
|
|
|
|
|
if return_numpy:
|
|
|
|
|
tensor = as_numpy(tensor)
|
|
|
|
|
return tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Executor(object):
|
|
|
|
|
def __init__(self, places):
|
|
|
|
|
if not isinstance(places, list) and not isinstance(places, tuple):
|
|
|
|
|