More general implementation.

emailweixu-patch-1
wanghaoshuang 7 years ago
parent 5ecbba46ce
commit a6a79c35c9

@ -186,12 +186,16 @@ def load_persistables(executor, dirname, main_program=None):
def get_inference_program(target_vars, main_program=None):
if main_program is None:
main_program = default_main_program()
if isinstance(target_vars, Evaluator):
target_vars = target_vars.states + target_vars.metrics
if not isinstance(target_vars, list):
target_vars = [target_vars]
pruned_program = main_program.prune(targets=target_vars)
vars = []
for var in target_vars:
if isinstance(var, Evaluator):
vars.append(var.states)
vars.append(var.metrics)
else:
vars.append(var)
pruned_program = main_program.prune(targets=vars)
inference_program = pruned_program.inference_optimize()
return inference_program

Loading…
Cancel
Save