|
|
@ -46,6 +46,7 @@ __all__ = [
|
|
|
|
'deserialize_program',
|
|
|
|
'deserialize_program',
|
|
|
|
'deserialize_persistables',
|
|
|
|
'deserialize_persistables',
|
|
|
|
'load_from_file',
|
|
|
|
'load_from_file',
|
|
|
|
|
|
|
|
'normalize_program',
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
_logger = get_logger(
|
|
|
|
_logger = get_logger(
|
|
|
@ -127,10 +128,64 @@ def _clone_var_in_block(block, var):
|
|
|
|
persistable=True)
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_program(program, feed_vars, fetch_vars):
|
|
|
|
def normalize_program(program, feed_vars, fetch_vars):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
optimize program according feed_vars and fetch_vars.
|
|
|
|
:api_attr: Static Graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Normalize/Optimize a program according to feed_vars and fetch_vars.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
program(Program): Specify a program you want to optimize.
|
|
|
|
|
|
|
|
feed_vars(Variable | list[Variable]): Variables needed by inference.
|
|
|
|
|
|
|
|
fetch_vars(Variable | list[Variable]): Variables returned by inference.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
Program: Normalized/Optimized program.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
|
|
|
TypeError: If `program` is not a Program, an exception is thrown.
|
|
|
|
|
|
|
|
TypeError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
|
|
|
|
|
|
|
|
TypeError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
path_prefix = "./infer_model"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# User defined network, here a softmax regession example
|
|
|
|
|
|
|
|
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
|
|
|
|
|
|
|
|
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
|
|
|
|
|
|
|
|
predict = paddle.static.nn.fc(image, 10, activation='softmax')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss = paddle.nn.functional.cross_entropy(predict, label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
exe = paddle.static.Executor(paddle.CPUPlace())
|
|
|
|
|
|
|
|
exe.run(paddle.static.default_startup_program())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# normalize main program.
|
|
|
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
|
|
|
normalized_program = paddle.static.normalize_program(program, [image], [predict])
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
if not isinstance(program, Program):
|
|
|
|
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"program type must be `fluid.Program`, but received `%s`" %
|
|
|
|
|
|
|
|
type(program))
|
|
|
|
|
|
|
|
if not isinstance(feed_vars, list):
|
|
|
|
|
|
|
|
feed_vars = [feed_vars]
|
|
|
|
|
|
|
|
if not all(isinstance(v, Variable) for v in feed_vars):
|
|
|
|
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"feed_vars type must be a Variable or a list of Variable.")
|
|
|
|
|
|
|
|
if not isinstance(fetch_vars, list):
|
|
|
|
|
|
|
|
fetch_vars = [fetch_vars]
|
|
|
|
|
|
|
|
if not all(isinstance(v, Variable) for v in fetch_vars):
|
|
|
|
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"fetch_vars type must be a Variable or a list of Variable.")
|
|
|
|
|
|
|
|
|
|
|
|
# remind users to set auc_states to 0 if auc op were found.
|
|
|
|
# remind users to set auc_states to 0 if auc op were found.
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
# clear device of Op
|
|
|
|
# clear device of Op
|
|
|
@ -255,7 +310,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
|
|
|
|
_check_vars('fetch_vars', fetch_vars)
|
|
|
|
_check_vars('fetch_vars', fetch_vars)
|
|
|
|
|
|
|
|
|
|
|
|
program = _get_valid_program(kwargs.get('program', None))
|
|
|
|
program = _get_valid_program(kwargs.get('program', None))
|
|
|
|
program = _normalize_program(program, feed_vars, fetch_vars)
|
|
|
|
program = normalize_program(program, feed_vars, fetch_vars)
|
|
|
|
return _serialize_program(program)
|
|
|
|
return _serialize_program(program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -319,7 +374,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
|
|
|
|
_check_vars('fetch_vars', fetch_vars)
|
|
|
|
_check_vars('fetch_vars', fetch_vars)
|
|
|
|
|
|
|
|
|
|
|
|
program = _get_valid_program(kwargs.get('program', None))
|
|
|
|
program = _get_valid_program(kwargs.get('program', None))
|
|
|
|
program = _normalize_program(program, feed_vars, fetch_vars)
|
|
|
|
program = normalize_program(program, feed_vars, fetch_vars)
|
|
|
|
return _serialize_persistables(program, executor)
|
|
|
|
return _serialize_persistables(program, executor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -463,7 +518,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
|
|
|
|
_check_vars('fetch_vars', fetch_vars)
|
|
|
|
_check_vars('fetch_vars', fetch_vars)
|
|
|
|
|
|
|
|
|
|
|
|
program = _get_valid_program(kwargs.get('program', None))
|
|
|
|
program = _get_valid_program(kwargs.get('program', None))
|
|
|
|
program = _normalize_program(program, feed_vars, fetch_vars)
|
|
|
|
program = normalize_program(program, feed_vars, fetch_vars)
|
|
|
|
# serialize and save program
|
|
|
|
# serialize and save program
|
|
|
|
program_bytes = _serialize_program(program)
|
|
|
|
program_bytes = _serialize_program(program)
|
|
|
|
save_to_file(model_path, program_bytes)
|
|
|
|
save_to_file(model_path, program_bytes)
|
|
|
|