@@ -840,6 +840,12 @@ def save_checkpoint(executor,
        max_num_checkpoints(int): The maximum total number of checkpoints
            to keep on disk.
            Default: 3
        lookup_table(string|None): the lookup table name. When using a
            distributed lookup table, the name can be obtained from
            DistributeTranspiler.table_name
        ps_endpoint_list(list|None): the parameter server ip:port list.
            When using a distributed lookup table, this list can be obtained
            from the distributed training arguments.

    Returns:
        None
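    # A minimal sketch (not part of this diff) of where the two new arguments
    # typically come from when a distributed lookup table is in use; the
    # endpoint string and trainer count below are illustrative assumptions:
    #
    #     t = fluid.DistributeTranspiler()
    #     t.transpile(trainer_id, pservers="127.0.0.1:6000,127.0.0.1:6001",
    #                 trainers=2)
    #     lookup_table = t.table_name
    #     ps_endpoint_list = ["127.0.0.1:6000", "127.0.0.1:6001"]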
@@ -856,15 +862,21 @@ def save_checkpoint(executor,
            prog = fluid.default_main_program()
            trainer_args = {"epoch_id": 200,
                            "step_id": 20}  # just an example
            table_name = "share_w"
            ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]

            fluid.io.save_checkpoint(executor=exe,
                                     checkpoint_dir=path,
                                     trainer_id=0,
                                     trainer_args=trainer_args,
                                     main_program=prog,
                                     max_num_checkpoints=3,
                                     lookup_table=table_name,
                                     ps_endpoint_list=ps_endpoints)
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")
    assert checkpoint_dir

    if trainer_args:
        assert isinstance(trainer_args, dict)
@@ -881,6 +893,7 @@ def save_checkpoint(executor,

    if is_chief:
        save_persist_vars_without_grad(executor, cur_dir, main_program)

    if is_chief and lookup_table and ps_endpoint_list:
        save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
                                    ps_endpoint_list)
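    # Note: only the chief trainer (trainer 0) persists the model variables,
    # and only when a distributed lookup table is in use does it notify the
    # pservers to dump their table shards next to the checkpoint.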
@@ -1020,6 +1033,31 @@ def load_persist_vars_without_grad(executor,


def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
    """
    The parameter server will load the lookup table's local file into
    a SelectedRows variable.

    Args:
        executor(Executor): The executor to run for loading persistable variables
        dirname(str): The directory path
        program(Program): find the variable named table_name in this program
        pserver_id(int): the index of the current pserver in the
            pserver_endpoints list
        table_name(str): lookup table name

    Returns:
        None

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            dirname = "./checkpoints/checkpoint_9/__model__"
            prog = fluid.default_main_program()
            pserver_id = 1
            table_name = "share_w"

            fluid.io.load_lookup_table_vars(executor=exe,
                dirname=dirname, program=prog, pserver_id=pserver_id,
                table_name=table_name)
    """
    for var in program.list_vars():
        if var.name == table_name:
@@ -1092,6 +1130,35 @@ def save_persist_vars_without_grad(executor, dirname, program):


def save_pserver_vars_by_notify(executor, dirname, lookup_table,
                                ps_endpoint_list):
    """
    This function sends a checkpoint notify message from trainer 0
    to all the pservers.
    The checkpoint notify message contains the lookup table name and
    the absolute path on each pserver where the lookup table should be saved.

    Args:
        executor(Executor): The executor to run for sending the checkpoint notify.
        dirname(str): The folder where to save checkpoints.
        lookup_table(string): the lookup table name. When using a
            distributed lookup table, the name can be obtained from
            DistributeTranspiler.table_name
        ps_endpoint_list(list): the parameter server ip:port list.
            When using a distributed lookup table, this list can be obtained
            from the distributed training arguments.

    Returns:
        None

    Examples:
        .. code-block:: python

            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
            table_name = "share_w"
            ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]

            fluid.io.save_pserver_vars_by_notify(executor=exe,
                dirname=param_path, lookup_table=table_name,
                ps_endpoint_list=ps_endpoints)
    """
    cur_dir = _get_lookuptable_dir(dirname)
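    # cur_dir is the absolute save path carried in the checkpoint notify
    # message, so every pserver writes its lookup table shard under the same
    # checkpoint directory.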
@@ -1121,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args):


def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
    """
    Each trainer loads some args from its own independent directory,
    such as epoch_id and step_id.

    Args:
        checkpoint_dir(str): The folder where all checkpoints are.
        serial(int): The serial of the checkpoint you would like to load.
        trainer_id(int): current trainer id.
        trainer_args(list): the names of the trainer args to load.

    Returns:
        None

    Examples:
        .. code-block:: python

            param_path = "./checkpoint/"
            serial = 7
            trainer_id = 2
            trainer_args = ["epoch_id", "step_id"]

            fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
                trainer_id=trainer_id, trainer_args=trainer_args)
    """
    assert isinstance(trainer_args, list)

    cur_dir = _get_serial_dir(checkpoint_dir, serial)
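    # Trainer args are stored per trainer under the serial checkpoint
    # directory (see save_trainer_args), so a load is keyed by both serial
    # and trainer_id.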
@@ -1141,7 +1231,7 @@ def _is_checkpoint_var(var):
    the checkpoint will not save or load all the variables.
    Variables whose type is FEED_MINIBATCH/FETCH_LIST/RAW, or whose name
    ends with @GRAD, are discarded.

    :param var(Variable)
    """
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \