|
|
@ -336,18 +336,20 @@ def save_inference_model(dirname,
|
|
|
|
|
|
|
|
|
|
|
|
if main_program is None:
|
|
|
|
if main_program is None:
|
|
|
|
main_program = default_main_program()
|
|
|
|
main_program = default_main_program()
|
|
|
|
|
|
|
|
copy_program = main_program
|
|
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
os.makedirs(dirname)
|
|
|
|
os.makedirs(dirname)
|
|
|
|
|
|
|
|
|
|
|
|
# Clear the is_target information and remove the existed feed and fetch op
|
|
|
|
# Clear the is_target information and remove the existed feed and fetch op
|
|
|
|
global_block = main_program.global_block()
|
|
|
|
global_block = copy_program.global_block()
|
|
|
|
for i, op in enumerate(global_block.ops):
|
|
|
|
for i, op in enumerate(global_block.ops):
|
|
|
|
op.desc.set_is_target(False)
|
|
|
|
op.desc.set_is_target(False)
|
|
|
|
if op.type == "feed" or op.type == "fetch":
|
|
|
|
if op.type == "feed" or op.type == "fetch":
|
|
|
|
global_block.remove_op(i)
|
|
|
|
global_block.remove_op(i)
|
|
|
|
|
|
|
|
copy_program.desc.flush()
|
|
|
|
|
|
|
|
|
|
|
|
pruned_program = main_program.prune(targets=target_vars)
|
|
|
|
pruned_program = copy_program.prune(targets=target_vars)
|
|
|
|
inference_program = pruned_program.inference_optimize()
|
|
|
|
inference_program = pruned_program.inference_optimize()
|
|
|
|
fetch_var_names = [v.name for v in target_vars]
|
|
|
|
fetch_var_names = [v.name for v in target_vars]
|
|
|
|
|
|
|
|
|
|
|
|