|
|
|
@ -20,6 +20,7 @@ import six
|
|
|
|
|
import textwrap
|
|
|
|
|
import threading
|
|
|
|
|
import warnings
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import framework
|
|
|
|
|
from paddle.fluid import core, executor
|
|
|
|
@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
|
|
|
|
|
from paddle.fluid.framework import in_dygraph_mode
|
|
|
|
|
from paddle.fluid.data_feeder import check_type
|
|
|
|
|
|
|
|
|
|
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
|
|
|
|
|
|
|
|
|
@ -261,19 +263,20 @@ class ProgramTranslator(object):
|
|
|
|
|
else:
|
|
|
|
|
self._exe = exe
|
|
|
|
|
self._program_cache = ProgramCache()
|
|
|
|
|
self._optimizer_info = None
|
|
|
|
|
self._optimizer = None
|
|
|
|
|
self._already_minimized = False
|
|
|
|
|
self._loss_name = None
|
|
|
|
|
# Once main_program is changed, should run startup_program.
|
|
|
|
|
self._need_startup = True
|
|
|
|
|
|
|
|
|
|
def get_output(self, dygraph_func, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Returns the output tensors for dygraph function and its arguments
|
|
|
|
|
Return the output tensors for dygraph function and its arguments
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The ProgramTranslator.get_output doesn't work in dygraph "
|
|
|
|
|
"mode. We will just return dygraph output. Use the it in "
|
|
|
|
|
"mode. We will just return dygraph output. Use it in "
|
|
|
|
|
"static mode if you would like to translate to static graph.")
|
|
|
|
|
return dygraph_func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -286,12 +289,12 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def get_func(self, dygraph_func):
|
|
|
|
|
"""
|
|
|
|
|
Returns the translated static function from dygraph function
|
|
|
|
|
Return the translated static function from dygraph function
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The ProgramTranslator.get_func doesn't work in dygraph "
|
|
|
|
|
"mode. We will just return dygraph function. Use the it in "
|
|
|
|
|
"mode. We will just return dygraph function. Use it in "
|
|
|
|
|
"static mode if you would like to translate to static graph.")
|
|
|
|
|
return dygraph_func
|
|
|
|
|
static_func = convert_function_with_cache(dygraph_func)
|
|
|
|
@ -299,7 +302,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def get_program(self, dygraph_func, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Returns the translated static program and input/output variables from
|
|
|
|
|
Return the translated static program and input/output variables from
|
|
|
|
|
dygraph function.
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
@ -315,7 +318,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def get_code(self, dygraph_func):
|
|
|
|
|
"""
|
|
|
|
|
Returns the translated static function code from dygraph code
|
|
|
|
|
Return the translated static function code from dygraph code
|
|
|
|
|
"""
|
|
|
|
|
# Get AST from dygraph function
|
|
|
|
|
raw_code = inspect.getsource(dygraph_func)
|
|
|
|
@ -332,7 +335,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def run(self, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Executes main_program and returns output Tensors.
|
|
|
|
|
Execute main_program and returns output Tensors.
|
|
|
|
|
"""
|
|
|
|
|
feed_dict, fetch_list = self._prepare(args)
|
|
|
|
|
|
|
|
|
@ -343,18 +346,18 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
def set_optimizer(self, optimizer, loss_name):
|
|
|
|
|
def set_optimizer(self, optimizer, index_of_loss=0):
|
|
|
|
|
"""
|
|
|
|
|
Supports to set or update the optimizer used to minimize loss.
|
|
|
|
|
Support to set or update the optimizer used to minimize loss.
|
|
|
|
|
"""
|
|
|
|
|
check_type(index_of_loss, "index_of_loss", int,
|
|
|
|
|
"ProgramTranslator.set_optimizer")
|
|
|
|
|
self._check_cache_valid()
|
|
|
|
|
self._optimizer = optimizer
|
|
|
|
|
|
|
|
|
|
if not isinstance(loss_name, six.string_types):
|
|
|
|
|
if self._optimizer and self._loss_name:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Type of input loss_name should type(str), but received {}.".
|
|
|
|
|
format(type(loss_name)))
|
|
|
|
|
self._loss_name = loss_name
|
|
|
|
|
"{} for {} has already been set before. Please confirm not to call `set_optimizer` in for loop. ".
|
|
|
|
|
format(self._optimizer, self._loss_name))
|
|
|
|
|
self._optimizer_info = (optimizer, index_of_loss)
|
|
|
|
|
|
|
|
|
|
def save_inference_model(self, dirname, feed=None, fetch=None):
|
|
|
|
|
"""
|
|
|
|
@ -377,16 +380,16 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def _prepare(self, args):
|
|
|
|
|
"""
|
|
|
|
|
Prepares with feed_dict, fetch_list, optimizer and initialize vars
|
|
|
|
|
Prepare with feed_dict, fetch_list, optimizer and initialize vars
|
|
|
|
|
by running startup_program.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# Updates batch_data for feed_dict
|
|
|
|
|
# Update batch_data for feed_dict
|
|
|
|
|
feed_dict = self._update_batch_data(args)
|
|
|
|
|
fetch_list = self._program_cache.outputs
|
|
|
|
|
|
|
|
|
|
# Adds optimizer if needed.
|
|
|
|
|
if self._optimizer and not self._already_minimized:
|
|
|
|
|
# Add optimizer if needed.
|
|
|
|
|
if self._optimizer_info and self._optimizer is None:
|
|
|
|
|
self._add_optimizer()
|
|
|
|
|
|
|
|
|
|
if self._need_startup:
|
|
|
|
@ -397,7 +400,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def _check_cache_valid(self):
|
|
|
|
|
"""
|
|
|
|
|
Checks whether the current program is consistent with `default_main_program`.
|
|
|
|
|
Check whether the current program is consistent with `default_main_program`.
|
|
|
|
|
In some models and unittest, program will be switched frequently by `program_guard`.
|
|
|
|
|
If does, the cached program and other properties are not available and should be reset.
|
|
|
|
|
"""
|
|
|
|
@ -408,7 +411,7 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def _update_batch_data(self, args):
|
|
|
|
|
"""
|
|
|
|
|
Updates cached batch data while training program.
|
|
|
|
|
Update cached batch data while training program.
|
|
|
|
|
"""
|
|
|
|
|
feed_name_to_idx = self._program_cache.feed_name_to_idx
|
|
|
|
|
feed_vars = self._program_cache.inputs
|
|
|
|
@ -421,27 +424,40 @@ class ProgramTranslator(object):
|
|
|
|
|
|
|
|
|
|
def _add_optimizer(self):
|
|
|
|
|
"""
|
|
|
|
|
Supports to set or update the optimizer used to minimize loss.
|
|
|
|
|
Support to set or update the optimizer used to minimize loss.
|
|
|
|
|
"""
|
|
|
|
|
optimizer, index_of_loss = self._optimizer_info
|
|
|
|
|
|
|
|
|
|
outputs = self._program_cache.outputs
|
|
|
|
|
outputs = [outputs] if not isinstance(outputs,
|
|
|
|
|
(list, tuple)) else outputs
|
|
|
|
|
|
|
|
|
|
assert abs(index_of_loss) < len(outputs), \
|
|
|
|
|
"index_of_loss: {} shall not exceed the length of outputs: {}.".format(
|
|
|
|
|
index_of_loss, len(outputs))
|
|
|
|
|
|
|
|
|
|
loss_var = outputs[index_of_loss]
|
|
|
|
|
check_type(loss_var, "loss_var", framework.Variable,
|
|
|
|
|
"ProgramTranslator._add_optimizer")
|
|
|
|
|
|
|
|
|
|
main_program = self._program_cache.main_program
|
|
|
|
|
startup_program = self._program_cache.startup_program
|
|
|
|
|
all_vars = main_program.block(0).vars
|
|
|
|
|
loss_var = all_vars.get(self._loss_name, None)
|
|
|
|
|
|
|
|
|
|
if loss_var is None:
|
|
|
|
|
if all_vars.get(loss_var.name, None) is None:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Can't find {} in main_program, please confirm whether the loss input is correct"
|
|
|
|
|
.format(self._loss_name))
|
|
|
|
|
# Adds optimizer to minimize loss
|
|
|
|
|
"Can't find {} in main_program, please confirm whether the input loss is correct."
|
|
|
|
|
.format(loss_var.name))
|
|
|
|
|
# Add optimizer to minimize loss
|
|
|
|
|
with framework.program_guard(main_program, startup_program):
|
|
|
|
|
self._optimizer.minimize(loss_var)
|
|
|
|
|
optimizer.minimize(loss_var)
|
|
|
|
|
|
|
|
|
|
# Avoids to set optimizer repeatedly.
|
|
|
|
|
self._already_minimized = True
|
|
|
|
|
self._optimizer = optimizer
|
|
|
|
|
self._loss_name = loss_var.name
|
|
|
|
|
|
|
|
|
|
def get_program_cache(self):
|
|
|
|
|
"""
|
|
|
|
|
Returns the ProgramCache instance.
|
|
|
|
|
Return the ProgramCache instance.
|
|
|
|
|
"""
|
|
|
|
|
self._check_cache_valid()
|
|
|
|
|
return self._program_cache
|
|
|
|
|