@@ -18,9 +18,11 @@ import inspect
 import logging
 import textwrap
 import threading
+import collections
 import numpy as np
-from paddle.fluid import core
+from paddle.fluid import core, scope_guard
 from paddle.fluid import framework
+from paddle.fluid import executor
 from paddle.fluid import unique_name
 from paddle.fluid.dygraph import layers
 from paddle.fluid.dygraph.base import switch_to_static_graph
@@ -92,10 +94,12 @@ class FunctionSpec(object):
         return self._args and isinstance(self._args[0], layers.Layer)

     def parameters(self, include_sublayer=True):
-        params = {}
+        params = collections.OrderedDict()
         if self.is_method():
             if include_sublayer:
                 params = self._args[0].parameters()
+                names = [p.name for p in params]
+                params = collections.OrderedDict(zip(names, params))
             else:
                 params = self._args[0]._parameters
         return params
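[Review note] The switch from a plain dict to a name-keyed collections.OrderedDict makes FunctionSpec.parameters() deterministic, which keeps the name-to-parameter mapping stable for the save path later in this patch. A minimal standalone sketch of the construction (the Param stand-in and the parameter names are illustrative, not part of this patch):

    import collections

    # Stand-in for a dygraph parameter; only the `name` attribute matters here.
    Param = collections.namedtuple('Param', ['name'])
    params = [Param('linear_0.w_0'), Param('linear_0.b_0')]

    # Same construction as the patched parameters():
    names = [p.name for p in params]
    ordered = collections.OrderedDict(zip(names, params))
    assert list(ordered) == ['linear_0.w_0', 'linear_0.b_0']  # insertion order kept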
@@ -155,11 +159,11 @@ class ConcreteProgram(object):
                  parameters,
                  func,
                  main_program,
-                 start_up=None):
+                 startup_program=None):
         self.inputs = inputs
         self.outputs = outputs
         self.main_program = main_program
-        self.startup_program = start_up
+        self.startup_program = startup_program
         self.parameters = parameters
         self.func_spec = func

@@ -174,18 +178,20 @@ class ConcreteProgram(object):
         dygaph_function = func_spec.dyfunc
         static_func = convert_function_with_cache(dygaph_function)

-        main_program, start_up = framework.Program(), framework.Program()
-        # Synchronous random seed of program
+        main_program, startup_program = framework.Program(), framework.Program()
+        # Note: The random seed should be synchronized into cached program
+        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
+        # `fluid.layers.dropout`.
         main_program.random_seed = framework.default_main_program().random_seed
-        start_up.random_seed = framework.default_startup_program().random_seed
+        startup_program.random_seed = framework.default_startup_program(
+        ).random_seed

-        with framework.program_guard(main_program, start_up):
+        with framework.program_guard(main_program, startup_program):
             # 1. Adds `fluid.data` layers for input if needed
             inputs = func_spec.to_static_inputs(main_program)

             # 2. Gets all ParamBases in the function
-            all_parameters = func_spec.parameters()
+            all_parameters = list(func_spec.parameters().values())

             # 3. Builds program only once and returns the output Variables.
             with param_guard(func_spec.parameters(False)):
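[Review note] The reworded comment is about reproducibility: ops such as fluid.layers.dropout read the enclosing program's random_seed, so a seed the user set under fluid.dygraph.guard must be mirrored onto the freshly built cached programs. A hedged sketch of that propagation (assumes a Paddle 1.x install; the seed value 102 is arbitrary):

    import paddle.fluid as fluid

    # A seed the user may have set on the default programs...
    fluid.default_main_program().random_seed = 102
    fluid.default_startup_program().random_seed = 102

    # ...is copied onto new programs, exactly as the patch does above.
    main_program, startup_program = fluid.Program(), fluid.Program()
    main_program.random_seed = fluid.default_main_program().random_seed
    startup_program.random_seed = fluid.default_startup_program().random_seed
    assert main_program.random_seed == 102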
@@ -199,7 +205,7 @@ class ConcreteProgram(object):
             parameters=all_parameters,
             func=dygaph_function,
             main_program=main_program,
-            start_up=start_up)
+            startup_program=startup_program)


 class ProgramCache(object):
@@ -208,7 +214,7 @@ class ProgramCache(object):
     """

     def __init__(self):
-        self._caches = {}
+        self._caches = collections.OrderedDict()

     def _build_once(self, func_spec):
         concrete_program = ConcreteProgram.from_func_spec(func_spec)
@@ -223,6 +229,12 @@ class ProgramCache(object):
             self._caches[item] = self._build_once(item)
         return self._caches[item]

+    def last(self):
+        assert len(
+            self._caches) >= 1, "No valid cached program in ProgramCache."
+        key = next(reversed(self._caches.keys()))
+        return key, self._caches[key]
+

 def synchronized(func):
     func.__lock__ = threading.Lock()
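[Review note] last() works because the cache container was changed to an OrderedDict above: next(reversed(...)) yields the most recently inserted key, i.e. the most recently built program. A standalone sketch with toy keys and values (not the real FunctionSpec/program objects):

    import collections

    caches = collections.OrderedDict()
    caches['func_a'] = 'program_a'
    caches['func_b'] = 'program_b'  # inserted (built) last

    key = next(reversed(caches.keys()))
    assert (key, caches[key]) == ('func_b', 'program_b')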
@@ -476,10 +488,20 @@ class ProgramTranslator(object):
         func_spec = FunctionSpec(dygraph_func, args, kwargs)
         concrete_program, _ = self._program_cache[func_spec]
+        # Note: concrete_program holds all input/output info, including
+        # non-Variable objects.
+        input_vars = [
+            var for var in concrete_program.inputs
+            if isinstance(var, framework.Variable)
+        ]
+        output_vars = [
+            var for var in concrete_program.outputs
+            if isinstance(var, framework.Variable)
+        ]
         return concrete_program.main_program, \
                concrete_program.startup_program, \
-               concrete_program.inputs, \
-               concrete_program.outputs
+               input_vars, \
+               output_vars

     def get_code(self, dygraph_func):
         """
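[Review note] The filtering exists because ConcreteProgram records every call argument and return value, not only tensors. A toy sketch of the same isinstance filter (stand-in class, not paddle.fluid.framework.Variable):

    class Variable(object):  # stand-in for framework.Variable
        def __init__(self, name):
            self.name = name

    # Recorded inputs may mix Variables with plain Python values:
    inputs = [Variable('x'), 0.5, 'train']
    input_vars = [var for var in inputs if isinstance(var, Variable)]
    assert [v.name for v in input_vars] == ['x']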
@@ -527,6 +549,96 @@ class ProgramTranslator(object):
         source_code = ast_to_source_code(root_wrapper.node)
         return source_code

+    def save_inference_model(self, dirname, feed=None, fetch=None):
+        """
+        Saves the current model as an inference model. It will prune the main_program
+        to build a new program especially for inference, and then save it and all
+        related parameters to the given `dirname`. The saved inference model can be
+        loaded by :ref:`api_fluid_io_load_inference_model` or C++ inference APIs.
+
+        Args:
+            dirname (str): the directory to save the inference model.
+            feed (list[int], optional): the input variable indices of the saved
+                inference model. If None, all input variables of the
+                ProgramTranslator would be the inputs of the saved inference
+                model. Default None.
+            fetch (list[int], optional): the output variable indices of the
+                saved inference model. If None, all output variables of the
+                ProgramTranslator would be the outputs of the saved inference
+                model. Default None.
+
+        Returns:
+            None
+
+        Examples:
+            .. code-block:: python
+
+                import numpy as np
+                import paddle.fluid as fluid
+                from paddle.fluid.dygraph import Linear
+                from paddle.fluid.dygraph import declarative
+                from paddle.fluid.dygraph import ProgramTranslator
+
+                class SimpleNet(fluid.dygraph.Layer):
+                    def __init__(self, in_size, out_size):
+                        super(SimpleNet, self).__init__()
+                        self._linear = Linear(in_size, out_size)
+
+                    @declarative
+                    def forward(self, x):
+                        y = self._linear(x)
+                        z = self._linear(y)
+                        loss = fluid.layers.mean(z)
+                        return z, loss
+
+                with fluid.dygraph.guard(fluid.CPUPlace()):
+                    net = SimpleNet(8, 8)
+                    adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters())
+                    x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32'))
+                    for i in range(10):
+                        out, loss = net(x)
+                        loss.backward()
+                        adam.minimize(loss)
+                        net.clear_gradients()
+                # Save inference model.
+                # Note that fetch=[0] means we set 'z' as the inference output.
+                prog_trans = ProgramTranslator()
+                prog_trans.save_inference_model("./dy2stat_infer_model", fetch=[0])
+
+                # In this example, the inference model will be pruned based on input (x) and
+                # output (z). The pruned inference program is going to be saved in the folder
+                # "./dy2stat_infer_model" and parameters are going to be saved in separate
+                # files in the folder.
+        """
+
+        def get_feed_fetch(var_list, partial_vars, return_name=False):
+            vars = [
+                var for var in var_list if isinstance(var, framework.Variable)
+            ]
+            if partial_vars:
+                vars = [vars[idx] for idx in partial_vars]
+            if return_name:
+                vars = [var.name for var in vars]
+
+            return vars
+
+        func_spec, (concrete_program,
+                    partial_layer) = self._program_cache.last()
+        # Share the data of each ParamBase with its parameter tensor in a new scope.
+        scope = core.Scope()
+        for param_base in concrete_program.parameters:
+            param_tensor = scope.var(param_base.name).get_tensor()
+            src_tensor = param_base.value().get_tensor()
+            param_tensor._share_data_with(src_tensor)
+
+        feed_var_names = get_feed_fetch(concrete_program.inputs, feed, True)
+        fetch_vars = get_feed_fetch(concrete_program.outputs, fetch)
+
+        from paddle.fluid.io import save_inference_model
+        with scope_guard(scope):
+            save_inference_model(
+                dirname=dirname,
+                feeded_var_names=feed_var_names,
+                target_vars=fetch_vars,
+                executor=executor.Executor(framework._current_expected_place()),
+                main_program=concrete_program.main_program.clone())
+
     def get_program_cache(self):
         """
         Returns the ProgramCache instance. This method is used by PaddlePaddle
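[Review note] For completeness, a hedged sketch of the load side that pairs with the new save_inference_model (assumes the docstring example above has already written "./dy2stat_infer_model"; uses the existing fluid.io.load_inference_model API):

    import numpy as np
    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    # Returns the pruned program plus the feed names and fetch targets
    # recorded at save time.
    inference_program, feed_target_names, fetch_targets = (
        fluid.io.load_inference_model(dirname="./dy2stat_infer_model",
                                      executor=exe))

    x = np.random.random((4, 8)).astype('float32')
    results = exe.run(inference_program,
                      feed={feed_target_names[0]: x},
                      fetch_list=fetch_targets)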