|
|
@ -495,11 +495,11 @@ def save_checkpoint(executor,
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
save_trainer_args(cur_dir, trainer_id, trainer_args)
|
|
|
|
|
|
|
|
|
|
|
|
if is_chief:
|
|
|
|
if is_chief:
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
save_trainer_args(cur_dir, trainer_id, trainer_args)
|
|
|
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def need_load_checkpoint(checkpoint_dir):
|
|
|
|
def need_load_checkpoint(checkpoint_dir):
|
|
|
@ -639,7 +639,13 @@ def _is_checkpoint_var(var):
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW:
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW:
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
if var.name.endswith("@GRAD"):
|
|
|
|
if "@GRAD" in var.name:
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ".trainer_" in var.name:
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ".block" in var.name:
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
return var.persistable
|
|
|
|
return var.persistable
|
|
|
|