|
|
|
@ -13,7 +13,6 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import cPickle as pickle
|
|
|
|
|
|
|
|
|
|
from paddle.v2.fluid.evaluator import Evaluator
|
|
|
|
|
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
|
|
|
|
@ -200,12 +199,16 @@ def get_inference_program(target_vars, main_program=None):
|
|
|
|
|
return inference_program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepend_feed_ops(inference_program, feeded_var_names):
|
|
|
|
|
def prepend_feed_ops(inference_program,
|
|
|
|
|
feed_target_names,
|
|
|
|
|
feed_holder_name='feed'):
|
|
|
|
|
global_block = inference_program.global_block()
|
|
|
|
|
feed_var = global_block.create_var(
|
|
|
|
|
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
|
|
|
|
|
name=feed_holder_name,
|
|
|
|
|
type=core.VarDesc.VarType.FEED_MINIBATCH,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
for i, name in enumerate(feeded_var_names):
|
|
|
|
|
for i, name in enumerate(feed_target_names):
|
|
|
|
|
out = global_block.var(name)
|
|
|
|
|
global_block.prepend_op(
|
|
|
|
|
type='feed',
|
|
|
|
@ -214,12 +217,16 @@ def prepend_feed_ops(inference_program, feeded_var_names):
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_fetch_ops(inference_program, fetch_var_names):
|
|
|
|
|
def append_fetch_ops(inference_program,
|
|
|
|
|
fetch_target_names,
|
|
|
|
|
fetch_holder_name='fetch'):
|
|
|
|
|
global_block = inference_program.global_block()
|
|
|
|
|
fetch_var = global_block.create_var(
|
|
|
|
|
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
|
|
|
|
|
name=fetch_holder_name,
|
|
|
|
|
type=core.VarDesc.VarType.FETCH_LIST,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
for i, name in enumerate(fetch_var_names):
|
|
|
|
|
for i, name in enumerate(fetch_target_names):
|
|
|
|
|
global_block.append_op(
|
|
|
|
|
type='fetch',
|
|
|
|
|
inputs={'X': [name]},
|
|
|
|
@ -269,21 +276,12 @@ def save_inference_model(dirname,
|
|
|
|
|
inference_program = pruned_program.inference_optimize()
|
|
|
|
|
fetch_var_names = [v.name for v in target_vars]
|
|
|
|
|
|
|
|
|
|
model_file_name = dirname + "/__model__"
|
|
|
|
|
with open(model_file_name, "w") as f:
|
|
|
|
|
pickle.dump({
|
|
|
|
|
"program_desc_str": inference_program.desc.serialize_to_string(),
|
|
|
|
|
"feed_var_names": feeded_var_names,
|
|
|
|
|
"fetch_var_names": fetch_var_names
|
|
|
|
|
}, f, -1)
|
|
|
|
|
|
|
|
|
|
prepend_feed_ops(inference_program, feeded_var_names)
|
|
|
|
|
append_fetch_ops(inference_program, fetch_var_names)
|
|
|
|
|
|
|
|
|
|
# Save only programDesc of inference_program in binary format
|
|
|
|
|
# in another file: __model__.dat
|
|
|
|
|
with open(model_file_name + ".dat", "wb") as fp:
|
|
|
|
|
fp.write(inference_program.desc.serialize_to_string())
|
|
|
|
|
model_file_name = dirname + "/__model__"
|
|
|
|
|
with open(model_file_name, "wb") as f:
|
|
|
|
|
f.write(inference_program.desc.serialize_to_string())
|
|
|
|
|
|
|
|
|
|
save_params(executor, dirname, main_program)
|
|
|
|
|
|
|
|
|
@ -306,6 +304,24 @@ def load_persistables_if_exist(executor, dirname, main_program=None):
|
|
|
|
|
predicate=_is_presistable_and_exist_)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_feed_targets_names(program):
    """Return the names of the variables written by the program's 'feed' ops.

    Feed ops are prepended to the global block (each new one lands at
    index 0), so a front-to-back scan visits them in reverse feed order;
    ``insert(0, ...)`` restores the original feed order.

    :param program: a Program whose global block may contain feed ops
    :return: list of feed target variable names, in feed order
    """
    feed_targets_names = []
    global_block = program.global_block()
    for op in global_block.ops:
        if op.desc.type() == 'feed':
            feed_targets_names.insert(0, op.desc.output('Out')[0])
    return feed_targets_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_fetch_targets_names(program):
    """Return the names of the variables read by the program's 'fetch' ops.

    Fetch ops are appended to the global block, so a front-to-back scan
    already yields them in fetch order.

    :param program: a Program whose global block may contain fetch ops
    :return: list of fetch source variable names, in fetch order
    """
    fetch_targets_names = []
    global_block = program.global_block()
    for op in global_block.ops:
        if op.desc.type() == 'fetch':
            fetch_targets_names.append(op.desc.input('X')[0])
    return fetch_targets_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_inference_model(dirname, executor):
    """
    Load inference model from a directory

    :param dirname: directory path
    :param executor: executor that load inference model

    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    """
    if not os.path.isdir(dirname):
        # Use '%' interpolation, not a comma: the old form passed dirname as
        # a second exception argument and never formatted it into the message.
        raise ValueError("There is no directory named '%s'" % dirname)

    # The model file holds only the binary-serialized ProgramDesc.
    model_file_name = dirname + "/__model__"
    with open(model_file_name, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables_if_exist(executor, dirname, program)

    # Feed/fetch targets are recovered from the feed/fetch ops that were
    # added to the program when it was saved.
    feed_target_names = get_feed_targets_names(program)
    fetch_target_names = get_fetch_targets_names(program)
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_parameter_value(para, executor):
|
|
|
|
|