|
|
|
@ -637,8 +637,8 @@ def save_inference_model(dirname,
|
|
|
|
|
if isinstance(target_vars, Variable):
|
|
|
|
|
target_vars = [target_vars]
|
|
|
|
|
elif export_for_deployment:
|
|
|
|
|
if not (bool(target_vars) and all(
|
|
|
|
|
isinstance(var, Variable) for var in target_vars)):
|
|
|
|
|
if not (bool(target_vars) and
|
|
|
|
|
all(isinstance(var, Variable) for var in target_vars)):
|
|
|
|
|
raise ValueError("'target_vars' should be a list of Variable.")
|
|
|
|
|
|
|
|
|
|
if main_program is None:
|
|
|
|
@ -667,10 +667,15 @@ def save_inference_model(dirname,
|
|
|
|
|
if export_for_deployment:
|
|
|
|
|
main_program = main_program.clone()
|
|
|
|
|
global_block = main_program.global_block()
|
|
|
|
|
need_to_remove_op_index = []
|
|
|
|
|
for i, op in enumerate(global_block.ops):
|
|
|
|
|
op.desc.set_is_target(False)
|
|
|
|
|
if op.type == "feed" or op.type == "fetch":
|
|
|
|
|
global_block._remove_op(i)
|
|
|
|
|
need_to_remove_op_index.append(i)
|
|
|
|
|
|
|
|
|
|
for index in need_to_remove_op_index[::-1]:
|
|
|
|
|
global_block._remove_op(index)
|
|
|
|
|
|
|
|
|
|
main_program.desc.flush()
|
|
|
|
|
|
|
|
|
|
main_program = main_program._prune(targets=target_vars)
|
|
|
|
|