|
|
|
@ -16,11 +16,9 @@ from __future__ import print_function
|
|
|
|
|
import gast
|
|
|
|
|
import inspect
|
|
|
|
|
import numpy
|
|
|
|
|
import six
|
|
|
|
|
import textwrap
|
|
|
|
|
import threading
|
|
|
|
|
import warnings
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import framework
|
|
|
|
|
from paddle.fluid import core, executor
|
|
|
|
@ -75,7 +73,7 @@ _FUNCTION_CACHE = FunctionCache()
|
|
|
|
|
|
|
|
|
|
def convert_function_with_cache(dygraph_func):
|
|
|
|
|
"""
|
|
|
|
|
Transform function of dygraph into static function using the cache mechanism.
|
|
|
|
|
Transforms function of dygraph into static function using the cache mechanism.
|
|
|
|
|
"""
|
|
|
|
|
with _CACHE_LOCK:
|
|
|
|
|
static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func)
|
|
|
|
@ -106,9 +104,9 @@ class ProgramCache(object):
|
|
|
|
|
self._main_program = framework.default_main_program()
|
|
|
|
|
self._startup_program = framework.default_startup_program()
|
|
|
|
|
self._func_cache = FunctionCache()
|
|
|
|
|
self._feed_name_to_idx = {}
|
|
|
|
|
# Stores the entry function of Net or Model.
|
|
|
|
|
self._forward_func = None
|
|
|
|
|
self._feed_name_to_idx = {}
|
|
|
|
|
self._is_repeated = False
|
|
|
|
|
# Indicates whether the function call is still building program.
|
|
|
|
|
# Because user can call recursively when `Net` has sub class in
|
|
|
|
@ -117,10 +115,10 @@ class ProgramCache(object):
|
|
|
|
|
|
|
|
|
|
def build_program_and_return_output(self, dyfunc, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Executes the main_program with specialized inputs so that the program
|
|
|
|
|
is built. This method also return outputs of program as fetch_list
|
|
|
|
|
Builds the main_program with specialized inputs and returns outputs
|
|
|
|
|
of program as fetch_list.
|
|
|
|
|
"""
|
|
|
|
|
# Transfroms dygraph function into static functions and caches them.
|
|
|
|
|
# Transforms dygraph function into static function and caches it.
|
|
|
|
|
static_func = self._transform_or_cache_layers(dyfunc)
|
|
|
|
|
|
|
|
|
|
# 1. Adds `fluid.data` layers for input if needed
|
|
|
|
@ -144,15 +142,23 @@ class ProgramCache(object):
|
|
|
|
|
Transforms dygraph function into static function.
|
|
|
|
|
"""
|
|
|
|
|
static_func = self._func_cache.get_or_cache_func(dyfunc)
|
|
|
|
|
# self._forward_func is entry function of Net or Model.
|
|
|
|
|
# It can be called for multiple times, but layers from these functions
|
|
|
|
|
# call stack will be added into self._main_program only once.
|
|
|
|
|
# After that, cached program will be always returned by default.
|
|
|
|
|
if static_func == self._forward_func:
|
|
|
|
|
self._is_repeated = True
|
|
|
|
|
|
|
|
|
|
if self._forward_func is None:
|
|
|
|
|
self._forward_func = static_func
|
|
|
|
|
else:
|
|
|
|
|
# self._forward_func is entry function of Net or Model.
|
|
|
|
|
# It can be called for multiple times, but layers from these functions
|
|
|
|
|
# call stack will be added into self._main_program only once.
|
|
|
|
|
# After that, cached program will be always returned by default.
|
|
|
|
|
if static_func == self._forward_func:
|
|
|
|
|
self._is_repeated = True
|
|
|
|
|
# If a independent function is received after the build process
|
|
|
|
|
# has finished, feed layers should be reset.
|
|
|
|
|
# TODO(Aurelius84): Switch main_program without specifying program_guard.
|
|
|
|
|
elif not self._in_build_process:
|
|
|
|
|
self._inputs = []
|
|
|
|
|
self._is_repeated = False
|
|
|
|
|
self._forward_func = static_func
|
|
|
|
|
|
|
|
|
|
return static_func
|
|
|
|
|
|
|
|
|
@ -180,8 +186,7 @@ class ProgramCache(object):
|
|
|
|
|
Adds `fluid.data` if the input `numpy.ndarray` is converted into `Variable`
|
|
|
|
|
by `to_variable()`, it makes program to be executed dynamically.
|
|
|
|
|
"""
|
|
|
|
|
if not self._feed_name_to_idx:
|
|
|
|
|
self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
|
|
|
|
|
self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
|
|
|
|
|
with framework.program_guard(self._main_program, self._startup_program):
|
|
|
|
|
for feed_name, idx in self.feed_name_to_idx.items():
|
|
|
|
|
batch_data = args[idx]
|
|
|
|
@ -267,12 +272,12 @@ class ProgramTranslator(object):
|
|
|
|
|
self._optimizer_info = None
|
|
|
|
|
self._optimizer = None
|
|
|
|
|
self._loss_name = None
|
|
|
|
|
# Once main_program is changed, should run startup_program.
|
|
|
|
|
self._need_startup = True
|
|
|
|
|
# Once startup_program is changed, should run startup_program.
|
|
|
|
|
self._prev_startup = None
|
|
|
|
|
|
|
|
|
|
def get_output(self, dygraph_func, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Return the output tensors for dygraph function and its arguments
|
|
|
|
|
Returns the output tensors for dygraph function and its arguments
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
warnings.warn(
|
|
|
|
@ -292,7 +297,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def get_func(self, dygraph_func):
|
|
|
|
|
"""
|
|
|
|
|
Return the translated static function from dygraph function
|
|
|
|
|
Returns the translated static function from dygraph function
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
warnings.warn(
|
|
|
|
@ -305,7 +310,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def get_program(self, dygraph_func, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Return the translated static program and input/output variables from
|
|
|
|
|
Returns the translated static program and input/output variables from
|
|
|
|
|
dygraph function.
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
@ -321,9 +326,9 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def get_code(self, dygraph_func):
|
|
|
|
|
"""
|
|
|
|
|
Return the translated static function code from dygraph code
|
|
|
|
|
Returns the translated static function code from dygraph code
|
|
|
|
|
"""
|
|
|
|
|
# Get AST from dygraph function
|
|
|
|
|
# Gets AST from dygraph function
|
|
|
|
|
raw_code = inspect.getsource(dygraph_func)
|
|
|
|
|
code = textwrap.dedent(raw_code)
|
|
|
|
|
root = gast.parse(code)
|
|
|
|
@ -338,7 +343,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def run(self, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Execute main_program and returns output Tensors.
|
|
|
|
|
Executes main_program and returns output Tensors.
|
|
|
|
|
"""
|
|
|
|
|
feed_dict, fetch_list = self._prepare(args)
|
|
|
|
|
|
|
|
|
@ -351,7 +356,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def set_optimizer(self, optimizer, index_of_loss=0):
|
|
|
|
|
"""
|
|
|
|
|
Support to set or update the optimizer used to minimize loss.
|
|
|
|
|
Supports to set or update the optimizer used to minimize loss.
|
|
|
|
|
"""
|
|
|
|
|
check_type(index_of_loss, "index_of_loss", int,
|
|
|
|
|
"ProgramTranslator.set_optimizer")
|
|
|
|
@ -364,7 +369,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def save_inference_model(self, dirname, feed=None, fetch=None):
|
|
|
|
|
"""
|
|
|
|
|
Save current model as the inference model.
|
|
|
|
|
Saves current model as the inference model.
|
|
|
|
|
"""
|
|
|
|
|
program_cache = self.get_program_cache()
|
|
|
|
|
if feed is None:
|
|
|
|
@ -383,27 +388,38 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def _prepare(self, args):
|
|
|
|
|
"""
|
|
|
|
|
Prepare with feed_dict, fetch_list, optimizer and initialize vars
|
|
|
|
|
Prepares with feed_dict, fetch_list, optimizer and initialize vars
|
|
|
|
|
by running startup_program.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# Update batch_data for feed_dict
|
|
|
|
|
# Updates batch_data for feed_dict
|
|
|
|
|
feed_dict = self._update_batch_data(args)
|
|
|
|
|
fetch_list = self._program_cache.outputs
|
|
|
|
|
|
|
|
|
|
# Add optimizer if needed.
|
|
|
|
|
# Adds optimizer if needed.
|
|
|
|
|
if self._optimizer_info and self._optimizer is None:
|
|
|
|
|
self._add_optimizer()
|
|
|
|
|
|
|
|
|
|
if self._need_startup:
|
|
|
|
|
if self._need_startup():
|
|
|
|
|
self._exe.run(self.startup_program)
|
|
|
|
|
self._need_startup = False
|
|
|
|
|
self._prev_startup = self.startup_program
|
|
|
|
|
|
|
|
|
|
return feed_dict, fetch_list
|
|
|
|
|
|
|
|
|
|
def _need_startup(self):
|
|
|
|
|
"""
|
|
|
|
|
Determines whether needy to run startup_program.
|
|
|
|
|
"""
|
|
|
|
|
if self.startup_program != self._prev_startup:
|
|
|
|
|
check_type(self.startup_program, "startup_program",
|
|
|
|
|
framework.Program, "_need_startup")
|
|
|
|
|
return len(self.startup_program.global_block().ops) > 0
|
|
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _check_cache_valid(self):
|
|
|
|
|
"""
|
|
|
|
|
Check whether the current program is consistent with `default_main_program`.
|
|
|
|
|
Checks whether the current program is consistent with `default_main_program`.
|
|
|
|
|
In some models and unittest, program will be switched frequently by `program_guard`.
|
|
|
|
|
If does, the cached program and other properties are not available and should be reset.
|
|
|
|
|
"""
|
|
|
|
@ -414,7 +430,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def _update_batch_data(self, args):
|
|
|
|
|
"""
|
|
|
|
|
Update cached batch data while training program.
|
|
|
|
|
Updates cached batch data while training program.
|
|
|
|
|
"""
|
|
|
|
|
feed_name_to_idx = self._program_cache.feed_name_to_idx
|
|
|
|
|
feed_vars = self._program_cache.inputs
|
|
|
|
@ -427,7 +443,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def _add_optimizer(self):
|
|
|
|
|
"""
|
|
|
|
|
Support to set or update the optimizer used to minimize loss.
|
|
|
|
|
Supports to set or update the optimizer used to minimize loss.
|
|
|
|
|
"""
|
|
|
|
|
optimizer, index_of_loss = self._optimizer_info
|
|
|
|
|
|
|
|
|
@ -451,7 +467,7 @@ class ProgramTranslator(object):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Can't find {} in main_program, please confirm whether the input loss is correct."
|
|
|
|
|
.format(loss_var.name))
|
|
|
|
|
# Add optimizer to minimize loss
|
|
|
|
|
# Adds optimizer to minimize loss
|
|
|
|
|
with framework.program_guard(main_program, startup_program):
|
|
|
|
|
optimizer.minimize(loss_var)
|
|
|
|
|
|
|
|
|
@ -460,7 +476,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def get_program_cache(self):
|
|
|
|
|
"""
|
|
|
|
|
Return the ProgramCache instance.
|
|
|
|
|
Returns the ProgramCache instance.
|
|
|
|
|
"""
|
|
|
|
|
self._check_cache_valid()
|
|
|
|
|
return self._program_cache
|
|
|
|
|