You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1259 lines
44 KiB
1259 lines
44 KiB
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import print_function
|
|
|
|
from ..wrapped_decorator import signature_safe_contextmanager
|
|
import os
|
|
import errno
|
|
import shutil
|
|
import six
|
|
import time
|
|
|
|
from .. import core
|
|
from .. import data_feeder
|
|
from .. import executor
|
|
from .. import framework
|
|
from .. import io
|
|
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
|
|
from .. import optimizer as opt_module
|
|
from .. import parallel_executor
|
|
from ..transpiler import distribute_transpiler
|
|
|
|
__all__ = [
|
|
'Trainer', 'BeginEpochEvent', 'EndEpochEvent', 'BeginStepEvent',
|
|
'EndStepEvent', 'CheckpointConfig'
|
|
]
|
|
|
|
|
|
class BeginEpochEvent(object):
|
|
"""
|
|
The begin of a training epoch.
|
|
|
|
Args:
|
|
epoch_id(int): The current epoch ID.
|
|
"""
|
|
|
|
def __init__(self, epoch_id):
|
|
self.epoch = epoch_id
|
|
|
|
|
|
class EndEpochEvent(object):
|
|
"""
|
|
The end of a training epoch.
|
|
|
|
Args:
|
|
epoch_id(int): The current epoch ID.
|
|
"""
|
|
|
|
def __init__(self, epoch_id):
|
|
self.epoch = epoch_id
|
|
|
|
|
|
class BeginStepEvent(object):
|
|
"""
|
|
The begin of a training epoch.
|
|
|
|
Args:
|
|
epoch_id(int): The current epoch ID.
|
|
step_id(int): The current step ID.
|
|
"""
|
|
|
|
def __init__(self, epoch_id, step_id):
|
|
self.epoch = epoch_id
|
|
self.step = step_id
|
|
self.fetch_metrics = True
|
|
"""
|
|
If fetch_metrics is true, the metrics will be fetched at the
|
|
EndStepEvent. Default is True.
|
|
"""
|
|
|
|
|
|
class EndStepEvent(object):
|
|
"""
|
|
The end of a training step.
|
|
|
|
Args:
|
|
epoch_id(int): The current epoch ID.
|
|
step_id(int): The current step ID.
|
|
metrics(list): A list of fetched tensor. The order of this list is same
|
|
as the :code:`train_func` returns.
|
|
"""
|
|
|
|
def __init__(self, epoch_id, step_id, metrics):
|
|
self.epoch = epoch_id
|
|
self.step = step_id
|
|
self.metrics = metrics
|
|
|
|
|
|
class CheckpointConfig(object):
|
|
"""
|
|
Parameter object for :code:`save_checkpoint` and
|
|
:code:`fluid.Trainer`. Used to configuration how to save checkpoint.
|
|
|
|
Args:
|
|
checkpoint_dir(str): Directory path to save check point. Default is the
|
|
current directory.
|
|
|
|
max_num_checkpoints(int): The max number of local check points.
|
|
epoch_interval(int): Every number of epoch to save check point.
|
|
step_interval(int): Every number of step to save check point.
|
|
|
|
Examples:
|
|
>>> config = fluid.CheckpointConfig("./checkpoints")
|
|
>>> trainer = fluid.Trainer(train_func=train_program,
|
|
>>> place=place,
|
|
>>> optimizer_func=optimizer_func,
|
|
>>> checkpoint_config=config)
|
|
>>> trainer.train(...)
|
|
"""
|
|
|
|
def __init__(self,
|
|
checkpoint_dir=None,
|
|
max_num_checkpoints=3,
|
|
epoch_interval=1,
|
|
step_interval=10):
|
|
|
|
assert epoch_interval >= 1
|
|
assert step_interval >= 1
|
|
|
|
self.checkpoint_dir = checkpoint_dir \
|
|
if checkpoint_dir is not None else os.getcwd()
|
|
self.max_num_checkpoints = max_num_checkpoints
|
|
self.epoch_interval = epoch_interval
|
|
self.step_interval = step_interval
|
|
self.epoch_id = 0
|
|
self.step_id = 0
|
|
self.load_serial = None
|
|
self.pserver_id = None
|
|
self.lookup_table_name = None
|
|
|
|
|
|
def check_and_get_place(place):
|
|
"""
|
|
Check the type of place or get the default place
|
|
Args:
|
|
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
|
|
|
|
Raises:
|
|
TypeError if the type mismatched.
|
|
|
|
Returns:
|
|
the original place if it is not None.
|
|
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
|
|
Otherwise returns CPUPlace by default.
|
|
"""
|
|
if place is None:
|
|
if core.is_compiled_with_cuda():
|
|
return core.CUDAPlace(0)
|
|
else:
|
|
return core.CPUPlace()
|
|
else:
|
|
if not isinstance(place, core.CUDAPlace) and not isinstance(
|
|
place, core.CPUPlace):
|
|
raise TypeError("Place should be either CUDAPlace or CPUPlace")
|
|
return place
|
|
|
|
|
|
class Trainer(object):
|
|
"""
|
|
A trainer wraps MultiGPU/MultiNode training loops and can be used to train a
|
|
simple neural network easily.
|
|
|
|
This API takes a :code:`train_func`. A :code:`train_func` is a function that
|
|
return loss as it first return value. The reset value can be fetched by
|
|
EndStepEvent.metrics
|
|
|
|
This API also takes a :code:`optimizer_func` that will return an optimizer
|
|
instance.
|
|
|
|
For example, to train a MLP for MNIST dataset, the sample program is
|
|
|
|
>>> import paddle.fluid as fluid
|
|
>>>
|
|
>>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10):
|
|
>>> hidden = image
|
|
>>> for layer_size in layer_sizes:
|
|
>>> hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation)
|
|
>>> return fluid.layers.fc(input=hidden, size=num_classes, act="softmax")
|
|
>>>
|
|
>>> def train_mnist_mlp():
|
|
>>> img = fluid.layers.data(name='image', shape=[784])
|
|
>>> label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
|
>>> prediction = mlp(img)
|
|
>>> return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label))
|
|
>>>
|
|
>>> def optimizer():
|
|
>>> return fluid.optimizer.Adam()
|
|
>>>
|
|
>>> trainer = Trainer(train_func=train_mnist_mlp,
|
|
>>> optimizer_func=optimizer,
|
|
>>> place=fluid.CUDAPlace(0),
|
|
>>> parallel=True)
|
|
>>>
|
|
>>> def train_callback(event):
|
|
>>> if isinstance(event, fluid.EndStepEvent):
|
|
>>> print "Epoch ID", event.epoch, "Step ID",\
|
|
>>> event.step, "AvgLoss", event.metrics[0]
|
|
>>> elif isinstance(event, fluid.EndEpochEvent):
|
|
>>> trainer.save_params("./model_{0}".format(event.epoch))
|
|
>>>
|
|
>>> trainer.train(num_epochs=100, event_handler=train_callback)
|
|
|
|
For more example, please see :ref:`api_guide_high_level_api`.
|
|
|
|
|
|
Args:
|
|
train_func(callable): A function which will return loss. The loss must be
|
|
a scalar tensor.
|
|
optimizer_func(callable): A function that returns an Optimizer object.
|
|
place(CUDAPlace|CPUPlace): The device place of this trainer. If
|
|
:code:`parallel=True,` all CUDA Places will be used if :code:`place`
|
|
is a :code:`CUDAPlace`.
|
|
parallel(bool): True if use multiple devices.
|
|
checkpoint_config(CheckpointConfig): Configuration about how to save
|
|
checkpoints.
|
|
"""
|
|
|
|
def __init__(self,
|
|
train_func,
|
|
optimizer_func,
|
|
param_path=None,
|
|
place=None,
|
|
parallel=False,
|
|
checkpoint_config=None):
|
|
self.__stop = False
|
|
self.parallel = parallel
|
|
|
|
# config for checkpoint
|
|
# only chief worker will save variables
|
|
self.trainer_id = 0
|
|
self.checkpoint_cfg = checkpoint_config
|
|
if self.checkpoint_cfg:
|
|
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
|
|
serial = _get_latest_checkpoint_serial(
|
|
self.checkpoint_cfg.checkpoint_dir)
|
|
self.checkpoint_cfg.load_serial = serial if serial >= 0 else None
|
|
|
|
self.scope = core.Scope()
|
|
|
|
# 1. we need to generate a framework.Program by calling
|
|
# program_func. Reference: fluid.program_guard in
|
|
# test_word2vec.py
|
|
|
|
self.startup_program = framework.Program()
|
|
self.train_program = framework.Program()
|
|
|
|
with framework.program_guard(self.train_program, self.startup_program):
|
|
program_func_outs = train_func()
|
|
self.train_func_outputs = program_func_outs if isinstance(
|
|
program_func_outs, list) else [program_func_outs]
|
|
self.test_program = self.train_program.clone(for_test=True)
|
|
|
|
# The first element of program_func_outs is loss.
|
|
loss = self.train_func_outputs[0]
|
|
|
|
optimizer = optimizer_func()
|
|
if not isinstance(optimizer, opt_module.Optimizer):
|
|
raise TypeError(
|
|
"The optimizer should be an instance of Optimizer")
|
|
optimize_ops, params_grads = optimizer.minimize(loss)
|
|
|
|
self.place = check_and_get_place(place)
|
|
|
|
self._dist_transpile_if_necessary(optimize_ops, params_grads)
|
|
|
|
# 2. move the default_main_program to self.program and run the
|
|
# default_startup program on an empty core.Scope()
|
|
# Run startup program
|
|
with self._prog_and_scope_guard():
|
|
exe = executor.Executor(place)
|
|
exe.run(self.startup_program)
|
|
|
|
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial is not None:
|
|
self._load_checkpoint()
|
|
|
|
if param_path and os.path.isdir(param_path):
|
|
with self._prog_and_scope_guard():
|
|
# load params from param_path into scope
|
|
io.load_persistables(
|
|
executor=exe,
|
|
dirname=param_path,
|
|
main_program=self.startup_program)
|
|
|
|
def _transpile_nccl2_dist(self):
|
|
# PADDLE_TRAINER_IPS
|
|
if "PADDLE_TRAINER_IPS" not in os.environ:
|
|
self.nccl_id_var = None
|
|
else:
|
|
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
|
|
port = os.getenv("PADDLE_PSERVER_PORT")
|
|
worker_ips = os.getenv("PADDLE_TRAINER_IPS")
|
|
worker_endpoints = []
|
|
for ip in worker_ips.split(","):
|
|
worker_endpoints.append(':'.join([ip, port]))
|
|
self.num_trainers = len(worker_endpoints)
|
|
current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
|
|
worker_endpoints.remove(current_endpoint)
|
|
# TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
|
|
# in ParallelExecutor to start
|
|
# distributed training using NCCL2
|
|
self.nccl_id_var = self.startup_program.global_block().create_var(
|
|
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
|
|
self.startup_program.global_block().append_op(
|
|
type="gen_nccl_id",
|
|
inputs={},
|
|
outputs={"NCCLID": self.nccl_id_var},
|
|
attrs={
|
|
"endpoint": current_endpoint,
|
|
"endpoint_list": worker_endpoints,
|
|
"trainer_id": self.trainer_id
|
|
})
|
|
|
|
def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
|
|
self._transpile_nccl2_dist()
|
|
if self.nccl_id_var != None:
|
|
return
|
|
|
|
if "PADDLE_TRAINING_ROLE" not in os.environ:
|
|
return
|
|
|
|
# the port of all pservers, needed by both trainer and pserver
|
|
port = os.getenv("PADDLE_PSERVER_PORT", "6174")
|
|
# comma separated ips of all pservers, needed by trainer and
|
|
# pserver
|
|
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
|
|
eplist = []
|
|
for ip in pserver_ips.split(","):
|
|
eplist.append(':'.join([ip, port]))
|
|
pserver_endpoints = ",".join(eplist)
|
|
# total number of workers/trainers in the job, needed by
|
|
# trainer and pserver
|
|
trainers = int(os.getenv("PADDLE_TRAINERS"))
|
|
# the IP of the local machine, needed by pserver only
|
|
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
|
|
# the unique trainer id, starting from 0, needed by trainer
|
|
# only
|
|
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
|
|
|
|
# the role, should be either PSERVER or TRAINER
|
|
training_role = os.getenv("PADDLE_TRAINING_ROLE")
|
|
with self._prog_and_scope_guard():
|
|
t = distribute_transpiler.DistributeTranspiler()
|
|
t.transpile(
|
|
self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
|
|
if training_role == "PSERVER":
|
|
if self.checkpoint_cfg:
|
|
pserver_id = eplist.index(current_endpoint)
|
|
self.checkpoint_cfg.pserver_id = pserver_id
|
|
if t.has_distributed_lookup_table:
|
|
self.checkpoint_cfg.lookup_table_name = t.table_name
|
|
|
|
self.train_program = t.get_pserver_program(current_endpoint)
|
|
self.startup_program = t.get_startup_program(current_endpoint,
|
|
self.train_program)
|
|
elif training_role == "TRAINER":
|
|
self.train_program = t.get_trainer_program()
|
|
else:
|
|
raise ValueError(
|
|
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
|
|
)
|
|
|
|
def stop(self):
|
|
"""
|
|
stop training
|
|
"""
|
|
self.__stop = True
|
|
|
|
def train(self, num_epochs, event_handler, reader=None, feed_order=None):
|
|
"""
|
|
Start the train loop to train the model.
|
|
|
|
Args:
|
|
num_epochs(int): The number of epoch. An epoch will process all data in reader
|
|
event_handler(callable): The event handler. A function with type (ev:Event)->void
|
|
reader(callable): A reader creator object. See also
|
|
:ref:`api_guide_python_reader` .
|
|
feed_order(list): Feeding order of reader. None will following the defining
|
|
order in program
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
|
|
if training_role == "PSERVER":
|
|
with self._prog_and_scope_guard():
|
|
exe = executor.Executor(self.place)
|
|
exe.run()
|
|
return
|
|
if self.parallel:
|
|
self._train_by_parallel_executor(num_epochs, event_handler, reader,
|
|
feed_order)
|
|
else:
|
|
self._train_by_executor(num_epochs, event_handler, reader,
|
|
feed_order)
|
|
|
|
def test(self, reader, feed_order):
|
|
"""
|
|
Test the model on given test data
|
|
|
|
Args:
|
|
reader(callable): The reader that yields test data.
|
|
feed_order(list): Feeding order of reader. None will following the
|
|
defining order in program
|
|
"""
|
|
|
|
return self._test_by_executor(reader, feed_order,
|
|
self.train_func_outputs)
|
|
|
|
def save_params(self, param_path):
|
|
"""
|
|
Save all parameters into :code:`param_path`.
|
|
|
|
Args:
|
|
param_path(str): The path to save parameters.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
with self._prog_and_scope_guard():
|
|
exe = executor.Executor(self.place)
|
|
io.save_persistables(exe, dirname=param_path)
|
|
|
|
def save_inference_model(self, param_path, feeded_var_names,
|
|
target_var_indexes):
|
|
"""
|
|
Save model for cpp inference into :code:`param_path`.
|
|
|
|
Args:
|
|
param_path(str): The path to save parameters.
|
|
feeded_var_names(list(str)): The name of the vars that you
|
|
need to feed in before run program.
|
|
target_var_indexes(list(int)): the index of target var that
|
|
you need to return in trainer.train_func.
|
|
Returns:
|
|
None
|
|
"""
|
|
with self._prog_and_scope_guard():
|
|
exe = executor.Executor(self.place)
|
|
target_vars = [
|
|
self.train_func_outputs[index] for index in target_var_indexes
|
|
]
|
|
io.save_inference_model(param_path, feeded_var_names, target_vars,
|
|
exe)
|
|
|
|
@signature_safe_contextmanager
|
|
def _prog_and_scope_guard(self):
|
|
with framework.program_guard(
|
|
main_program=self.train_program,
|
|
startup_program=self.startup_program):
|
|
with executor.scope_guard(self.scope):
|
|
yield
|
|
|
|
def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
|
|
"""
|
|
Train by Executor and single device.
|
|
|
|
Args:
|
|
num_epochs:
|
|
event_handler:
|
|
reader:
|
|
feed_order:
|
|
|
|
Returns:
|
|
|
|
"""
|
|
with self._prog_and_scope_guard():
|
|
feed_var_list = build_feed_var_list(self.train_program, feed_order)
|
|
feeder = data_feeder.DataFeeder(
|
|
feed_list=feed_var_list, place=self.place)
|
|
exe = executor.Executor(self.place)
|
|
reader = feeder.decorate_reader(reader, multi_devices=False)
|
|
self._train_by_any_executor(event_handler, exe, num_epochs, reader)
|
|
|
|
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
|
|
if self.checkpoint_cfg:
|
|
epochs = [
|
|
epoch_id for epoch_id in range(num_epochs)
|
|
if epoch_id >= self.checkpoint_cfg.epoch_id
|
|
]
|
|
else:
|
|
epochs = [epoch_id for epoch_id in range(num_epochs)]
|
|
|
|
for epoch_id in epochs:
|
|
event_handler(BeginEpochEvent(epoch_id))
|
|
for step_id, data in enumerate(reader()):
|
|
if self.__stop:
|
|
if self.checkpoint_cfg:
|
|
self._clean_checkpoint()
|
|
return
|
|
|
|
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
|
|
and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
|
|
continue
|
|
|
|
begin_event = BeginStepEvent(epoch_id, step_id)
|
|
event_handler(begin_event)
|
|
if begin_event.fetch_metrics:
|
|
metrics = exe.run(feed=data,
|
|
fetch_list=[
|
|
var.name
|
|
for var in self.train_func_outputs
|
|
])
|
|
else:
|
|
metrics = exe.run(feed=data, fetch_list=[])
|
|
|
|
if self.checkpoint_cfg:
|
|
self._save_checkpoint(epoch_id, step_id)
|
|
event_handler(EndStepEvent(epoch_id, step_id, metrics))
|
|
event_handler(EndEpochEvent(epoch_id))
|
|
if self.checkpoint_cfg:
|
|
self._clean_checkpoint()
|
|
|
|
def _test_by_executor(self, reader, feed_order, fetch_list):
|
|
with executor.scope_guard(self.scope):
|
|
feed_var_list = build_feed_var_list(self.test_program, feed_order)
|
|
feeder = data_feeder.DataFeeder(
|
|
feed_list=feed_var_list, place=self.place)
|
|
exe = executor.Executor(self.place)
|
|
accumulated = len(fetch_list) * [0]
|
|
count = 0
|
|
for data in reader():
|
|
outs = exe.run(program=self.test_program,
|
|
feed=feeder.feed(data),
|
|
fetch_list=fetch_list)
|
|
accumulated = [x[0] + x[1][0] for x in zip(accumulated, outs)]
|
|
count += 1
|
|
|
|
return [x / count for x in accumulated]
|
|
|
|
def _train_by_parallel_executor(self, num_epochs, event_handler, reader,
|
|
feed_order):
|
|
with self._prog_and_scope_guard():
|
|
pe = self._get_or_create_parallel_executor()
|
|
feed_var_list = build_feed_var_list(self.train_program, feed_order)
|
|
feeder = data_feeder.DataFeeder(
|
|
feed_list=feed_var_list, place=self.place)
|
|
reader = feeder.decorate_reader(reader, multi_devices=True)
|
|
self._train_by_any_executor(event_handler, pe, num_epochs, reader)
|
|
|
|
def _get_parallel_executor(self):
|
|
return getattr(self, 'parallel_executor', None)
|
|
|
|
def _get_or_create_parallel_executor(self):
|
|
if self._get_parallel_executor() is None:
|
|
self.parallel_executor = parallel_executor.ParallelExecutor(
|
|
use_cuda=isinstance(self.place, core.CUDAPlace),
|
|
loss_name=self.train_func_outputs[0].name)
|
|
return self._get_parallel_executor()
|
|
|
|
def _clean_checkpoint(self):
|
|
assert self.checkpoint_cfg
|
|
clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
|
|
|
|
def _get_checkpoint_load_args(self):
|
|
"""
|
|
epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
|
|
"""
|
|
return ["epoch_id", "step_id"]
|
|
|
|
def _get_checkpoint_save_args(self, epoch_id, step_id):
|
|
"""
|
|
epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
|
|
"""
|
|
trainer_args = {}
|
|
trainer_args["epoch_id"] = epoch_id
|
|
trainer_args["step_id"] = step_id
|
|
return trainer_args
|
|
|
|
def _save_checkpoint(self, epoch_id, step_id):
|
|
assert self.checkpoint_cfg
|
|
|
|
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
|
|
and step_id % self.checkpoint_cfg.step_interval == 0:
|
|
exe = executor.Executor(self.place)
|
|
save_checkpoint(
|
|
executor=exe,
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
trainer_id=self.trainer_id,
|
|
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
|
|
main_program=self.train_program,
|
|
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
|
|
|
|
def _load_checkpoint(self):
|
|
with self._prog_and_scope_guard():
|
|
exe = executor.Executor(self.place)
|
|
load_checkpoint(
|
|
executor=exe,
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
main_program=self.startup_program)
|
|
|
|
if not self.checkpoint_cfg.pserver_id:
|
|
load_trainer_args = self._get_checkpoint_load_args()
|
|
trainer_args = load_checkpoint(
|
|
executor=exe,
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
main_program=self.startup_program,
|
|
role_id=self.trainer_id,
|
|
is_trainer=True,
|
|
load_trainer_args=load_trainer_args)
|
|
|
|
if len(trainer_args) != 2:
|
|
raise ValueError(
|
|
"the return trainer_args length do not equal _get_checkpoint_load_args"
|
|
)
|
|
self.checkpoint_cfg.epoch_id = int(trainer_args[0])
|
|
self.checkpoint_cfg.step_id = int(trainer_args[1])
|
|
else:
|
|
if self.checkpoint_cfg.lookup_table_name:
|
|
load_checkpoint(
|
|
executor=exe,
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
main_program=self.startup_program,
|
|
role_id=self.checkpoint_cfg.pserver_id,
|
|
is_trainer=False,
|
|
load_trainer_args=None,
|
|
load_lookup_table=self.checkpoint_cfg.lookup_table_name)
|
|
|
|
|
|
def build_feed_var_list(program, feed_order):
|
|
if not isinstance(program, framework.Program):
|
|
raise TypeError("The 'program' should be an object of Program")
|
|
|
|
if isinstance(feed_order, list):
|
|
feed_var_list = [
|
|
program.global_block().var(var_name) for var_name in feed_order
|
|
]
|
|
else:
|
|
if not isinstance(feed_order, dict):
|
|
raise TypeError(
|
|
"The 'feed_order' should be either None, list or dict.")
|
|
if not sorted(feed_order.values()) == list(range(len(feed_order))):
|
|
raise ValueError(
|
|
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
|
|
)
|
|
sorted_pair_list = sorted(
|
|
six.iteritems(feed_order), key=lambda item: item[1])
|
|
feed_var_list = [
|
|
program.global_block().var(pair[0]) for pair in sorted_pair_list
|
|
]
|
|
return feed_var_list
|
|
|
|
|
|
# move Checkpoint APIs from io.py to trainer.py, make all of them are private.
|
|
SUCCESS_MARK_FILENAME = "_SUCCESS"
|
|
CHECKPOINT_PREFIX = "checkpoint"
|
|
MODEL_DIR = "__model__"
|
|
LOOKUP_TABLE_DIR = "__lookup_table__"
|
|
TRAINER_PREFIX = "trainer"
|
|
CHECKPOINT_SEPARATOR = "_"
|
|
|
|
|
|
def save_checkpoint(executor,
|
|
checkpoint_dir,
|
|
trainer_id,
|
|
main_program,
|
|
trainer_args=None,
|
|
max_num_checkpoints=3,
|
|
lookup_table=None,
|
|
pserver_endpoints=None):
|
|
"""
|
|
This function filters out all checkpoint variables from the give
|
|
main_program and then saves these variables to the `checkpoint_dir`
|
|
directory.
|
|
|
|
In the training precess, we generally save a checkpoint in each
|
|
iteration. So there might be a lot of checkpoints in the
|
|
`checkpoint_dir`. To avoid them taking too much disk space, the
|
|
`max_num_checkpoints` are introduced to limit the total number of
|
|
checkpoints. If the number of existing checkpints is greater than
|
|
the `max_num_checkpoints`, oldest ones will be scroll deleted.
|
|
|
|
A variable is a checkpoint variable and will be saved if it meets
|
|
all following conditions:
|
|
1. It's persistable.
|
|
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
|
|
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
|
|
|
|
Args:
|
|
executor(Executor): The executor to run for save checkpoint.
|
|
checkpoint_dir(str): The folder where to save checkpoints.
|
|
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
|
|
is chief.
|
|
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
|
|
and 'step_id'.
|
|
Defaut: None
|
|
main_program(Program): The program whose checkpoint variables will
|
|
be saved.
|
|
max_num_checkpoints(int): The max number of total number of existing
|
|
checkpoints.
|
|
Default: 3
|
|
lookup_table(string|None): the lookup table name, when use distribute
|
|
lookup table, we can get lookup table name by DistributeTranspiler.
|
|
table_name
|
|
pserver_endpoints(list|None): the parameter server ip:port list.
|
|
when use distribute lookup table, we can get pserver_endpoints by
|
|
distribute arguments.
|
|
|
|
Returns:
|
|
None
|
|
|
|
Raises:
|
|
ValueError: If `checkpoint_dir` is None.
|
|
AssertionError: If `trainer_args` is not a dict.
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
path = "./checkpoints"
|
|
prog = fluid.default_main_program()
|
|
trainer_args = {"epoch_id": 200,
|
|
"step_id": 20} # just an example
|
|
table_name = "share_w"
|
|
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
|
|
|
|
save_checkpoint(executor=exe,
|
|
checkpoint_dir=path,
|
|
trainer_id=0,
|
|
trainer_args=trainer_args,
|
|
main_program=prog,
|
|
max_num_checkpoints=3,
|
|
lookup_table=table_name,
|
|
pserver_endpoints = ps_endpoints)
|
|
"""
|
|
if checkpoint_dir is None:
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
|
|
if main_program is None:
|
|
raise ValueError('main_program should not be None.')
|
|
|
|
if trainer_args:
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
is_chief = trainer_id == 0
|
|
|
|
_make_chekcpoint_dirs(checkpoint_dir)
|
|
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
_save_trainer_args(cur_dir, trainer_id, trainer_args)
|
|
|
|
if is_chief:
|
|
_save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
if is_chief and lookup_table and pserver_endpoints:
|
|
_save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
|
|
pserver_endpoints)
|
|
|
|
_scroll_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
def load_checkpoint(executor,
|
|
checkpoint_dir,
|
|
main_program,
|
|
role_id=0,
|
|
is_trainer=True,
|
|
load_trainer_args=None,
|
|
load_lookup_table=None):
|
|
"""
|
|
This function filters out all checkpoint variables from the give
|
|
main_program and then try to load these variables from the
|
|
`checkpoint_dir` directory.
|
|
|
|
In the training precess, we generally save a checkpoint in each
|
|
iteration. So there are more than one checkpoint in the
|
|
`checkpoint_dir` (each checkpoint has its own sub folder), use
|
|
`serial` to specify which serial of checkpoint you would like to
|
|
load.
|
|
|
|
A variable is a checkpoint variable and will be loaded if it meets
|
|
all following conditions:
|
|
1. It's persistable.
|
|
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
|
|
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
|
|
|
|
Args:
|
|
executor(Executor): The executor to run for loading checkpoint.
|
|
checkpoint_dir(str): The folder where all checkpoints are.
|
|
serial(int): The serial of checkpoint you would like to load.
|
|
main_program(Program): The program whose checkpoint variables will
|
|
be loaded.
|
|
role_id(int): the trainer id or the parameter server id.
|
|
is_trainer(bool): trainer is True and parameter server is False.
|
|
load_trainer_args(list|None): list about load trainer args.
|
|
load_lookup_table(str|None): the lookup table name
|
|
|
|
Returns:
|
|
None
|
|
|
|
Raises:
|
|
ValueError: If `checkpoint_dir` is None.
|
|
ValueError: If `main_program` is None.
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
path = "./checkpoints"
|
|
prog = fluid.default_main_program()
|
|
load_checkpoint(executor=exe, checkpoint_dir=path,
|
|
serial=9, main_program=prog)
|
|
|
|
# In this example, `load_checkpoint` function
|
|
# will first filters out all checkpoint variables in the default
|
|
# main program, and then try to load these variables form the
|
|
# folder "./checkpoints/checkpoint_9/__model__".
|
|
"""
|
|
|
|
if checkpoint_dir is None:
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
|
|
serial = _get_latest_checkpoint_serial(checkpoint_dir)
|
|
|
|
# there are nothing need to be loaded
|
|
if serial is None or serial < 0:
|
|
return
|
|
|
|
if main_program is None:
|
|
raise ValueError('main_program should not be None.')
|
|
|
|
if is_trainer and load_trainer_args is None:
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
_load_persist_vars_without_grad(executor, cur_dir, main_program, True)
|
|
return
|
|
|
|
if is_trainer and load_trainer_args:
|
|
return _load_trainer_args(checkpoint_dir, serial, role_id,
|
|
load_trainer_args)
|
|
|
|
if not is_trainer and load_lookup_table:
|
|
_load_lookup_table_vars(executor, checkpoint_dir, main_program, role_id,
|
|
load_lookup_table)
|
|
|
|
|
|
def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
"""
|
|
clean the checkpoint dir, when the train exits normally,
|
|
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
|
|
delete_dir only works when the directory is empty, otherwise, OSError is raised.
|
|
|
|
: param checkpoint_dir
|
|
: param delete_dir
|
|
"""
|
|
|
|
if checkpoint_dir is None:
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
_scroll_delete(checkpoint_dir, max_num_checkpoints=0)
|
|
|
|
if delete_dir and not os.listdir(checkpoint_dir):
|
|
os.rmdir(checkpoint_dir)
|
|
|
|
|
|
def _load_persist_vars_without_grad(executor,
|
|
dirname,
|
|
program,
|
|
has_model_dir=False):
|
|
"""
|
|
This function filters out all checkpoint variables from the give
|
|
program and then trys to load these variables from the given directory.
|
|
|
|
A variable is a checkpoint variable if it meets all following
|
|
conditions:
|
|
1. It's persistable.
|
|
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
|
|
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
|
|
|
|
Args:
|
|
executor(Executor): The executor to run for loading variables.
|
|
dirname(str): The directory path.
|
|
program(Program): The program whose checkpoint variables will
|
|
be loaded.
|
|
has_model_dir(bool): if True, the function loads variables
|
|
from a sub directory named '__model__'.
|
|
Default: False
|
|
|
|
Returns:
|
|
None
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
param_path = "./my_paddle_model"
|
|
prog = fluid.default_main_program()
|
|
_load_persist_vars_without_grad(executor=exe,
|
|
dirname=param_path, program=prog, has_model_dir=True)
|
|
|
|
# In this example, `_load_persist_vars_without_grad` function
|
|
# will first filters out all checkpoint variables in the default
|
|
# main program, and then trys to load these variables form the
|
|
# folder "./my_paddle_model/__model__".
|
|
"""
|
|
|
|
if has_model_dir:
|
|
dirname = _get_model_dir(dirname)
|
|
|
|
io.load_vars(
|
|
executor,
|
|
dirname=dirname,
|
|
main_program=program,
|
|
predicate=_is_checkpoint_var,
|
|
filename=None)
|
|
|
|
|
|
def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
|
|
"""
|
|
The parameter server will load lookup table's local file in
|
|
selectedrows variable.
|
|
|
|
Args:
|
|
executor(Executor): The executor to run for loading persistable variables
|
|
dirname(str): The directory path
|
|
main_program(Program): Find the variable named table_name in main_program
|
|
pserver_id(int): the serial number in pserver_endpoints list
|
|
table_name(str): lookup table name
|
|
|
|
Returns:
|
|
None
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
dirname = "./checkpoints/checkpoint_9/"
|
|
prog = fluid.default_main_program()
|
|
pserver_id = 1
|
|
table_name = "share_w"
|
|
_load_lookup_table_vars(executor=exe,
|
|
dirname=dirname, program=prog, pserver_id=pserver_id,
|
|
table_name=table_name)
|
|
"""
|
|
|
|
for var in program.list_vars():
|
|
if var.name == table_name:
|
|
lookup_table_var = var
|
|
break
|
|
|
|
assert lookup_table_var is not None
|
|
|
|
lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
|
|
table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id)
|
|
|
|
load_prog = framework.Program()
|
|
load_block = load_prog.global_block()
|
|
|
|
load_block.append_op(
|
|
type='load',
|
|
inputs={},
|
|
outputs={'Out': [lookup_table_var]},
|
|
attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
|
|
|
|
executor.run(load_prog)
|
|
|
|
|
|
def _save_persist_vars_without_grad(executor, dirname, program):
|
|
"""
|
|
This function filters out all checkpoint variables from the give
|
|
program and then save these variables to a sub-folder '__model__' of
|
|
the given directory.
|
|
|
|
A variable is a checkpoint variable if it meets all following
|
|
conditions:
|
|
1. It's persistable.
|
|
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
|
|
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
|
|
|
|
Args:
|
|
executor(Executor): The executor to run for saving variables.
|
|
dirname(str): The directory path.
|
|
program(Program): The program whose checkpoint variables will
|
|
be saved.
|
|
|
|
Returns:
|
|
None
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
param_path = "./my_paddle_model"
|
|
prog = fluid.default_main_program()
|
|
_save_persist_vars_without_grad(executor=exe,
|
|
dirname=param_path, program=prog)
|
|
|
|
# In this example, `_save_persist_vars_without_grad` function
|
|
# will first filters out all checkpoint variables in the default
|
|
# main program, and then saves these variables to the folder
|
|
# "./my_paddle_model/__model__".
|
|
"""
|
|
cur_dir = _get_model_dir(dirname)
|
|
io.save_vars(
|
|
executor,
|
|
dirname=cur_dir,
|
|
main_program=program,
|
|
vars=None,
|
|
predicate=_is_checkpoint_var,
|
|
filename=None)
|
|
_write_success(cur_dir)
|
|
|
|
|
|
def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
|
|
ps_endpoint_list):
|
|
"""
|
|
This function will send checkpoint notify message from Trainer 0
|
|
to all the pservers.
|
|
The checkpoint notify message contains lookup table name,
|
|
the absolute path on pserver to save lookup_table.
|
|
|
|
Args:
|
|
executor(Executor): The executor to run for send checkpoint notify.
|
|
dirname(str): The folder where to save checkpoints.
|
|
lookup_table(string): the lookup table name, when use distribute
|
|
lookup table, we can get lookup table name by DistributeTranspiler.
|
|
table_name
|
|
ps_endpoint_list(list): the parameter server ip:port list.
|
|
when use distribute lookup table, we can get ps_endpoint_list by
|
|
distribute arguments.
|
|
Return:
|
|
None
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
param_path = "./my_paddle_model"
|
|
prog = fluid.default_main_program()
|
|
table_name = "share_w"
|
|
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
|
|
|
|
_save_pserver_vars_by_notify(executor=exe,
|
|
dirname=param_path, lookup_table=table_name,
|
|
ps_endpoint_list=ps_endpoints)
|
|
"""
|
|
cur_dir = _get_lookuptable_dir(dirname)
|
|
|
|
checkpoint_notify_program = framework.Program()
|
|
checkpoint_notify_block = checkpoint_notify_program.global_block()
|
|
|
|
attrs = {}
|
|
attrs['epmap'] = ps_endpoint_list
|
|
attrs['dir'] = cur_dir
|
|
attrs['lookup_table'] = lookup_table
|
|
|
|
checkpoint_notify_block.append_op(
|
|
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
|
|
executor.run(checkpoint_notify_program)
|
|
|
|
|
|
def _save_trainer_args(dirname, trainer_id, trainer_args):
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
cur_dir = _get_trainer_dir(dirname, trainer_id)
|
|
|
|
for name, value in six.iteritems(trainer_args):
|
|
args_file = os.path.join(cur_dir, name)
|
|
with open(args_file, 'w') as f:
|
|
f.write(str(value))
|
|
_write_success(cur_dir)
|
|
|
|
|
|
def _load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
|
|
"""
|
|
trainer will load some args from it's independent directory,
|
|
such as epoch_id and step_id.
|
|
|
|
Args:
|
|
checkpoint_dir(str): The folder where all checkpoints are.
|
|
serial(int): The serial of checkpoint you would like to load.
|
|
trainer_id(int): current trainer id.
|
|
trainer_args(list): list about load trainer args
|
|
Return:
|
|
None
|
|
|
|
Examples:
|
|
.. code-block:: python
|
|
|
|
param_path = "./checkpoint/"
|
|
serial = 7
|
|
trainer_id = 2
|
|
trainer_args = ["epoch_id", "step_id"]
|
|
|
|
_load_trainer_args(checkpoint_dir=param_path, serial=serial,
|
|
trainer_id=trainer_id, trainer_args=trainer_args)
|
|
"""
|
|
assert isinstance(trainer_args, list)
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
|
|
|
|
ret_values = []
|
|
|
|
for arg in trainer_args:
|
|
cur_file = os.path.join(cur_dir, arg)
|
|
with open(cur_file, 'r') as f:
|
|
contents = f.read()
|
|
ret_values.append(contents.strip())
|
|
return ret_values
|
|
|
|
|
|
def _is_checkpoint_var(var):
|
|
"""
|
|
the checkpoint will not save or load all the variables.
|
|
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
|
|
|
|
: param var(Variable)
|
|
"""
|
|
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
var.desc.type() == core.VarDesc.VarType.RAW:
|
|
return False
|
|
# @GRAD are named for gradient variables, checkpoint will not save it.
|
|
if "@GRAD" in var.name:
|
|
return False
|
|
# .trainer_ are named for distribute train variables, checkpoint will not save it.
|
|
if ".trainer_" in var.name:
|
|
return False
|
|
|
|
# .block is named for distribute train variables, checkpoint will not save it.
|
|
if ".block" in var.name:
|
|
return False
|
|
|
|
return var.persistable
|
|
|
|
|
|
def _make_chekcpoint_dirs(dirs):
|
|
"""
|
|
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
|
|
"""
|
|
assert dirs is not None
|
|
|
|
if os.path.isfile(dirs):
|
|
raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)
|
|
|
|
if not os.path.isdir(dirs):
|
|
try:
|
|
os.makedirs(dirs)
|
|
except OSError as err:
|
|
if err.errno != errno.EEXIST:
|
|
raise err
|
|
|
|
|
|
def _get_dir_serial(dirname):
|
|
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
|
|
|
|
try:
|
|
serial_num = int(serial)
|
|
except ValueError:
|
|
serial_num = -1
|
|
return serial_num
|
|
|
|
|
|
def _get_serial_dir(dirname, serial):
|
|
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
|
|
serial_dir = os.path.join(dirname, serial_folder)
|
|
_make_chekcpoint_dirs(serial_dir)
|
|
|
|
return serial_dir
|
|
|
|
|
|
def _get_model_dir(dirname):
|
|
model_dir = os.path.join(dirname, MODEL_DIR)
|
|
_make_chekcpoint_dirs(model_dir)
|
|
return model_dir
|
|
|
|
|
|
def _get_lookuptable_dir(dirname):
|
|
lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
|
|
_make_chekcpoint_dirs(lookuptable_dir)
|
|
return lookuptable_dir
|
|
|
|
|
|
def _get_trainer_dir(dirname, trainer_id):
|
|
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
|
|
trainer_dir = os.path.join(dirname, trainer_folder)
|
|
_make_chekcpoint_dirs(trainer_dir)
|
|
return trainer_dir
|
|
|
|
|
|
def _scroll_delete(dirname, max_num_checkpoints=3):
|
|
dirs = os.listdir(dirname)
|
|
serial_map = {}
|
|
for serial in dirs:
|
|
serial_num = _get_dir_serial(serial)
|
|
serial_map[serial_num] = serial
|
|
|
|
if len(list(serial_map.keys())) <= max_num_checkpoints:
|
|
return
|
|
|
|
serials = list(serial_map.keys())
|
|
serials.sort(reverse=True)
|
|
serials = serials[max_num_checkpoints:]
|
|
for serial in serials:
|
|
cur_dir = _get_serial_dir(dirname, serial)
|
|
try:
|
|
shutil.rmtree(cur_dir)
|
|
except OSError as err:
|
|
if err.errno != errno.ENOENT:
|
|
raise err
|
|
|
|
|
|
def _write_success(dirname):
|
|
"""
|
|
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
|
|
|
|
: param dirname
|
|
"""
|
|
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
|
|
with open(success_file, 'a') as f:
|
|
now = time.ctime()
|
|
f.write(now)
|
|
|
|
|
|
def _get_latest_checkpoint_serial(checkpoint_dir):
|
|
"""
|
|
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
|
|
|
|
: param checkpoint_dir
|
|
"""
|
|
if not checkpoint_dir:
|
|
return -1
|
|
|
|
def has_success(checkpoint_dir, cur_dir):
|
|
"""
|
|
is _SUCCESS in this dir
|
|
"""
|
|
|
|
serial = _get_dir_serial(cur_dir)
|
|
if serial == -1 or not os.path.isdir(
|
|
os.path.join(checkpoint_dir, cur_dir)):
|
|
return -1
|
|
|
|
success_path = os.path.join(
|
|
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
|
|
SUCCESS_MARK_FILENAME)
|
|
if os.path.isfile(success_path):
|
|
return serial
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
return -1
|
|
|
|
current_dir = -1
|
|
dirs = os.listdir(checkpoint_dir)
|
|
for cur_dir in dirs:
|
|
success_num = has_success(checkpoint_dir, cur_dir)
|
|
if success_num > current_dir:
|
|
current_dir = success_num
|
|
return current_dir
|