|
|
|
|
@ -192,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
|
|
|
|
|
return inference_program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepend_feed_ops(inference_program, feeded_var_names):
|
|
|
|
|
global_block = inference_program.global_block()
|
|
|
|
|
feed_var = global_block.create_var(
|
|
|
|
|
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
|
|
|
|
|
|
|
|
|
|
for i, name in enumerate(feeded_var_names):
|
|
|
|
|
out = global_block.var(name)
|
|
|
|
|
global_block.prepend_op(
|
|
|
|
|
type='feed',
|
|
|
|
|
inputs={'X': [feed_var]},
|
|
|
|
|
outputs={'Out': [out]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_fetch_ops(inference_program, fetch_var_names):
|
|
|
|
|
global_block = inference_program.global_block()
|
|
|
|
|
fetch_var = global_block.create_var(
|
|
|
|
|
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
|
|
|
|
|
|
|
|
|
|
for i, name in enumerate(fetch_var_names):
|
|
|
|
|
global_block.append_op(
|
|
|
|
|
type='fetch',
|
|
|
|
|
inputs={'X': [name]},
|
|
|
|
|
outputs={'Out': [fetch_var]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_inference_model(dirname,
|
|
|
|
|
feeded_var_names,
|
|
|
|
|
target_vars,
|
|
|
|
|
@ -244,27 +271,8 @@ def save_inference_model(dirname,
|
|
|
|
|
|
|
|
|
|
# Save only programDesc of inference_program in binary format
|
|
|
|
|
# in another file: __model__.dat
|
|
|
|
|
global_block = inference_program.global_block()
|
|
|
|
|
feed_var = global_block.create_var(
|
|
|
|
|
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
|
|
|
|
|
|
|
|
|
|
for i, name in enumerate(feeded_var_names):
|
|
|
|
|
out = global_block.var(name)
|
|
|
|
|
global_block.prepend_op(
|
|
|
|
|
type='feed',
|
|
|
|
|
inputs={'X': [feed_var]},
|
|
|
|
|
outputs={'Out': [out]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
fetch_var = global_block.create_var(
|
|
|
|
|
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
|
|
|
|
|
|
|
|
|
|
for i, name in enumerate(fetch_var_names):
|
|
|
|
|
global_block.append_op(
|
|
|
|
|
type='fetch',
|
|
|
|
|
inputs={'X': [name]},
|
|
|
|
|
outputs={'Out': [fetch_var]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
prepend_feed_ops(inference_program, feeded_var_names)
|
|
|
|
|
append_fetch_ops(inference_program, fetch_var_names)
|
|
|
|
|
|
|
|
|
|
with open(model_file_name + ".dat", "wb") as fp:
|
|
|
|
|
fp.write(inference_program.desc.serialize_to_string())
|
|
|
|
|
|