rename 'feed_dict' in ParallelExecutor.run() to 'feed'

wangkuiyi-patch-2
JiayiFeng 7 years ago
parent a78b92854d
commit 22df230ee4

@ -61,8 +61,8 @@ class ParallelExecutor(object):
main_program=test_program,
share_vars_from=train_exe)
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
test_loss, = test_exe.run([loss.name], feed=feed_dict)
"""
self._places = []
@ -123,22 +123,23 @@ class ParallelExecutor(object):
allow_op_delay)
self.scope = scope
def run(self, fetch_list, feed_dict={}):
def run(self, fetch_list, feed={}, feed_dict={}):
"""
:param fetch_list: A list of variable names that will be fetched.
:param feed_dict: A dict mapping for feed variable name to LoDTensor
:param feed: A dict mapping for feed variable name to LoDTensor
or numpy array.
:return: fetched value list.
"""
if not isinstance(feed_dict, dict):
raise TypeError("feed_dict should be a dict")
feed = feed_dict
if not isinstance(feed, dict):
raise TypeError("feed should be a dict")
feed_tensor_dict = {}
for i, feed_name in enumerate(feed_dict):
feed_tensor = feed_dict[feed_name]
for i, feed_name in enumerate(feed):
feed_tensor = feed[feed_name]
if not isinstance(feed_tensor, core.LoDTensor):
feed_tensor = core.LoDTensor()
feed_tensor.set(feed_dict[feed_name], self._act_places[0])
feed_tensor.set(feed[feed_name], self._act_places[0])
feed_tensor_dict[feed_name] = feed_tensor
fetch_var_name = '@FETCHED_VAR_NAME@'

Loading…
Cancel
Save