|
|
|
@ -340,6 +340,13 @@ def save_inference_model(dirname,
|
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
|
os.makedirs(dirname)
|
|
|
|
|
|
|
|
|
|
# Clear the is_target information and remove the existed feed and fetch op
|
|
|
|
|
global_block = main_program.global_block()
|
|
|
|
|
for i, op in enumerate(global_block.ops):
|
|
|
|
|
op.desc.set_is_target(False)
|
|
|
|
|
if op.type == "feed" or op.type == "fetch":
|
|
|
|
|
global_block.remove_op(i)
|
|
|
|
|
|
|
|
|
|
pruned_program = main_program.prune(targets=target_vars)
|
|
|
|
|
inference_program = pruned_program.inference_optimize()
|
|
|
|
|
fetch_var_names = [v.name for v in target_vars]
|
|
|
|
@ -362,24 +369,6 @@ def save_inference_model(dirname,
|
|
|
|
|
save_persistables(executor, dirname, inference_program, params_filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_feed_targets_names(program):
|
|
|
|
|
feed_targets_names = []
|
|
|
|
|
global_block = program.global_block()
|
|
|
|
|
for op in global_block.ops:
|
|
|
|
|
if op.desc.type() == 'feed':
|
|
|
|
|
feed_targets_names.insert(0, op.desc.output('Out')[0])
|
|
|
|
|
return feed_targets_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_fetch_targets_names(program):
|
|
|
|
|
fetch_targets_names = []
|
|
|
|
|
global_block = program.global_block()
|
|
|
|
|
for op in global_block.ops:
|
|
|
|
|
if op.desc.type() == 'fetch':
|
|
|
|
|
fetch_targets_names.append(op.desc.input('X')[0])
|
|
|
|
|
return fetch_targets_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_inference_model(dirname,
|
|
|
|
|
executor,
|
|
|
|
|
model_filename=None,
|
|
|
|
@ -418,8 +407,8 @@ def load_inference_model(dirname,
|
|
|
|
|
program = Program.parse_from_string(program_desc_str)
|
|
|
|
|
load_persistables(executor, dirname, program, params_filename)
|
|
|
|
|
|
|
|
|
|
feed_target_names = get_feed_targets_names(program)
|
|
|
|
|
fetch_target_names = get_fetch_targets_names(program)
|
|
|
|
|
feed_target_names = program.desc.get_feed_target_names()
|
|
|
|
|
fetch_target_names = program.desc.get_fetch_target_names()
|
|
|
|
|
fetch_targets = [
|
|
|
|
|
program.global_block().var(name) for name in fetch_target_names
|
|
|
|
|
]
|
|
|
|
|