|
|
|
@ -17,6 +17,7 @@ from __future__ import print_function
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import multiprocessing
|
|
|
|
|
import sys
|
|
|
|
|
import numpy as np
|
|
|
|
|
from .wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
import six
|
|
|
|
@ -627,6 +628,23 @@ class Executor(object):
|
|
|
|
|
|
|
|
|
|
list(numpy.array): fetch result according to fetch_list.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
return self._run_impl(
|
|
|
|
|
program=program,
|
|
|
|
|
feed=feed,
|
|
|
|
|
fetch_list=fetch_list,
|
|
|
|
|
feed_var_name=feed_var_name,
|
|
|
|
|
fetch_var_name=fetch_var_name,
|
|
|
|
|
scope=scope,
|
|
|
|
|
return_numpy=return_numpy,
|
|
|
|
|
use_program_cache=use_program_cache)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
if not isinstance(e, core.EOFException):
|
|
|
|
|
print("An exception was thrown!\n {}".format(str(e)))
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
|
|
def _run_impl(self, program, feed, fetch_list, feed_var_name,
|
|
|
|
|
fetch_var_name, scope, return_numpy, use_program_cache):
|
|
|
|
|
|
|
|
|
|
if self._closed:
|
|
|
|
|
raise RuntimeError("Attempted to use a closed Executor")
|
|
|
|
@ -639,7 +657,7 @@ class Executor(object):
|
|
|
|
|
compiled = isinstance(program, compiler.CompiledProgram)
|
|
|
|
|
# For backward compatibility, run directly.
|
|
|
|
|
if not compiled:
|
|
|
|
|
return self._run(
|
|
|
|
|
return self._run_program(
|
|
|
|
|
program,
|
|
|
|
|
self._default_executor,
|
|
|
|
|
feed=feed,
|
|
|
|
@ -672,7 +690,7 @@ class Executor(object):
|
|
|
|
|
# TODO(panyx0718): executor should be able to run graph.
|
|
|
|
|
assert program._program, "CompiledProgram is compiled from graph, can only run with_data_parallel."
|
|
|
|
|
# use_program_cache is not valid with CompiledProgram
|
|
|
|
|
return self._run(
|
|
|
|
|
return self._run_program(
|
|
|
|
|
program._program,
|
|
|
|
|
self._default_executor,
|
|
|
|
|
feed=feed,
|
|
|
|
@ -683,8 +701,8 @@ class Executor(object):
|
|
|
|
|
return_numpy=return_numpy,
|
|
|
|
|
use_program_cache=False)
|
|
|
|
|
|
|
|
|
|
def _run(self, program, exe, feed, fetch_list, feed_var_name,
|
|
|
|
|
fetch_var_name, scope, return_numpy, use_program_cache):
|
|
|
|
|
def _run_program(self, program, exe, feed, fetch_list, feed_var_name,
|
|
|
|
|
fetch_var_name, scope, return_numpy, use_program_cache):
|
|
|
|
|
|
|
|
|
|
if feed is None:
|
|
|
|
|
feed = {}
|
|
|
|
|