|
|
|
|
@ -666,11 +666,22 @@ def save_inference_model(dirname,
|
|
|
|
|
|
|
|
|
|
save_persistables(executor, dirname, inference_program, params_filename)
|
|
|
|
|
|
|
|
|
|
# if there is lookup table, the trainer 0 will notify all pserver to save.
|
|
|
|
|
if main_program._is_distributed and main_program._is_chief:
|
|
|
|
|
if main_program._distributed_lookup_table:
|
|
|
|
|
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
|
|
|
|
|
_save_lookup_tables_by_notify(
|
|
|
|
|
executor, lookup_table_filename,
|
|
|
|
|
main_program._distributed_lookup_table, main_program._endpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_inference_model(dirname,
|
|
|
|
|
executor,
|
|
|
|
|
model_filename=None,
|
|
|
|
|
params_filename=None):
|
|
|
|
|
params_filename=None,
|
|
|
|
|
training_role=None,
|
|
|
|
|
role_id=None,
|
|
|
|
|
pserver_endpoints=None):
|
|
|
|
|
"""
|
|
|
|
|
Load inference model from a directory
|
|
|
|
|
|
|
|
|
|
@ -736,6 +747,12 @@ def load_inference_model(dirname,
|
|
|
|
|
program = Program.parse_from_string(program_desc_str)
|
|
|
|
|
load_persistables(executor, dirname, program, params_filename)
|
|
|
|
|
|
|
|
|
|
if pserver_endpoints:
|
|
|
|
|
_endpoints_replacement(program, pserver_endpoints)
|
|
|
|
|
|
|
|
|
|
if training_role == "PSERVER":
|
|
|
|
|
_load_lookup_table_vars(executor, dirname, program, role_id)
|
|
|
|
|
|
|
|
|
|
feed_target_names = program.desc.get_feed_target_names()
|
|
|
|
|
fetch_target_names = program.desc.get_fetch_target_names()
|
|
|
|
|
fetch_targets = [
|
|
|
|
|
@ -745,6 +762,118 @@ def load_inference_model(dirname,
|
|
|
|
|
return [program, feed_target_names, fetch_targets]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
|
|
|
|
|
pserver_endpoints):
|
|
|
|
|
"""
|
|
|
|
|
This function will send checkpoint notify message from Trainer 0
|
|
|
|
|
to all the pservers.
|
|
|
|
|
The checkpoint notify message contains lookup table name,
|
|
|
|
|
the absolute path on pserver to save lookup_table.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
executor(Executor): The executor to run for send checkpoint notify.
|
|
|
|
|
dirname(str): The folder where to save.
|
|
|
|
|
lookup_table(string): the lookup table name, when use distribute
|
|
|
|
|
lookup table, we can get lookup table name by DistributeTranspiler.
|
|
|
|
|
table_name
|
|
|
|
|
ps_endpoint_list(list): the parameter server ip:port list.
|
|
|
|
|
when use distribute lookup table, we can get ps_endpoint_list by
|
|
|
|
|
distribute arguments.
|
|
|
|
|
Return:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
|
|
|
param_path = "./my_paddle_model"
|
|
|
|
|
table_name = "share_w"
|
|
|
|
|
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
|
|
|
|
|
|
|
|
|
|
_save_pserver_vars_by_notify(executor=exe,
|
|
|
|
|
dirname=param_path, lookup_table=table_name,
|
|
|
|
|
pserver_endpoints=ps_endpoints)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
pserver_notify_program = Program()
|
|
|
|
|
pserver_notify_block = pserver_notify_program.global_block()
|
|
|
|
|
|
|
|
|
|
attrs = {}
|
|
|
|
|
attrs['epmap'] = pserver_endpoints.split(",")
|
|
|
|
|
attrs['dir'] = dirname
|
|
|
|
|
attrs['lookup_table'] = lookup_table
|
|
|
|
|
|
|
|
|
|
pserver_notify_block.append_op(
|
|
|
|
|
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
|
|
|
|
|
executor.run(pserver_notify_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_lookup_table_vars(executor, dirname, program, pserver_id):
|
|
|
|
|
"""
|
|
|
|
|
The parameter server will load lookup table's local file in
|
|
|
|
|
selectedrows variable.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
executor(Executor): The executor to run for loading persistable variables
|
|
|
|
|
dirname(str): The directory path
|
|
|
|
|
main_program(Program): Find the variable named table_name in main_program
|
|
|
|
|
pserver_id(int): the serial number in pserver_endpoints list
|
|
|
|
|
table_name(str): lookup table name
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
|
|
|
dirname = "./checkpoints/checkpoint_9/"
|
|
|
|
|
prog = fluid.default_main_program()
|
|
|
|
|
pserver_id = 1
|
|
|
|
|
table_name = "share_w"
|
|
|
|
|
_load_lookup_table_vars(executor=exe,
|
|
|
|
|
dirname=dirname, program=prog, pserver_id=pserver_id,
|
|
|
|
|
table_name=table_name)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
LOOKUP_TABLE_TYPE = "lookup_table"
|
|
|
|
|
lookup_table_var_name = None
|
|
|
|
|
|
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
|
if op.type == LOOKUP_TABLE_TYPE:
|
|
|
|
|
if op.attrs['is_distributed'] is True:
|
|
|
|
|
if lookup_table_var_name is None:
|
|
|
|
|
lookup_table_var_name = op.input("W")[0]
|
|
|
|
|
if lookup_table_var_name != op.input("W")[0]:
|
|
|
|
|
raise RuntimeError("all distributed lookup_table_ops"
|
|
|
|
|
" should have only one table")
|
|
|
|
|
|
|
|
|
|
lookup_table_var = program.global_block().vars[lookup_table_var_name]
|
|
|
|
|
if lookup_table_var is None:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
lookup_table_dir = os.path.join(dirname, "__lookup_table__")
|
|
|
|
|
table_file = "{}.{}".format(lookup_table_var.name, pserver_id)
|
|
|
|
|
|
|
|
|
|
load_prog = Program()
|
|
|
|
|
load_block = load_prog.global_block()
|
|
|
|
|
|
|
|
|
|
load_block.append_op(
|
|
|
|
|
type='load',
|
|
|
|
|
inputs={},
|
|
|
|
|
outputs={'Out': [lookup_table_var]},
|
|
|
|
|
attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
|
|
|
|
|
|
|
|
|
|
executor.run(load_prog)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _endpoints_replacement(program, endpoints):
|
|
|
|
|
ENDPOINT_MAP = "epmap"
|
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
|
if op.attrs.has_key(ENDPOINT_MAP):
|
|
|
|
|
op.attrs[ENDPOINT_MAP] = endpoints
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_parameter_value(para, executor):
|
|
|
|
|
"""
|
|
|
|
|
Get the LoDTensor value of the given parameter.
|
|
|
|
|
|