Save checkpoint automatically (#25917)
parent e853ece0a2
commit 0067a2e4ec
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

File diff suppressed because it is too large
@@ -0,0 +1,223 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..fleet.utils.fs import FS, LocalFS
from ..fleet.utils.hdfs import HDFSClient
from ...compiler import CompiledProgram


class SerializableBase(object):
    def serialize(self, path):
        raise NotImplementedError

    def deserialize(self, path):
        raise NotImplementedError


class PaddleModel(SerializableBase):
    def __init__(self, exe, program):
        self._exe = exe
        self._origin_program = program
        self._program = program
        if isinstance(program, CompiledProgram):
            self._program = program._program

        self._file_name = "_paddle_fleet_param__"

    def serialize(self, path):
        from ...io import save_persistables
        save_persistables(
            executor=self._exe,
            dirname=path,
            main_program=self._program,
            filename=self._file_name)

    def deserialize(self, path):
        from ...io import load_persistables
        load_persistables(
            executor=self._exe,
            dirname=path,
            main_program=self._program,
            filename=self._file_name)
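For orientation, a minimal usage sketch of PaddleModel. The toy program below is an assumption for illustration; only the PaddleModel calls come from this file.

import paddle.fluid as fluid
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel

# hypothetical toy program; any fluid Program with persistable variables works
exe = fluid.Executor(fluid.CPUPlace())
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[-1, 4], dtype='float32')
    fluid.layers.fc(x, size=2)
exe.run(startup_prog)

model = PaddleModel(exe, main_prog)
model.serialize("./params_dir")    # writes a single "_paddle_fleet_param__" file
model.deserialize("./params_dir")  # reads it back into the program's variables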


class CheckpointSaver(object):
    def __init__(self, fs):
        self._fs = fs
        self._checkpoint_prefix = "__paddle_checkpoint__"

    def save_checkpoint(self,
                        path,
                        slists,
                        trainer_id=None,
                        local_cache_path=".cache"):
        """
        Serialize every object in slists to path.
        Return the actually saved path and the checkpoint number.
        """
        if not self._fs.is_exist(path):
            self._fs.mkdirs(path)
        else:
            assert self._fs.is_dir(path), "path:{} must be a directory".format(
                path)

        max_no = self._get_last_checkpoint_no(path)
        if max_no < 0:
            max_no = -1
        max_no += 1

        real_path = "{}/{}.{}".format(path, self._checkpoint_prefix, max_no)
        tmp_path = "{}.tmp".format(real_path)
        saved_path = tmp_path

        local_fs = LocalFS()

        cache_path = None
        if self._fs.need_upload_download():
            cache_path = "{}/{}.{}.saved_cache".format(
                local_cache_path, self._checkpoint_prefix, max_no)

            if trainer_id is not None:
                cache_path = "{}.{}".format(cache_path, trainer_id)

            if not local_fs.is_exist(cache_path):
                local_fs.mkdirs(cache_path)
            else:
                assert local_fs.is_dir(cache_path), \
                    "cache path:{} must be a directory".format(cache_path)

            saved_path = cache_path

        for s in slists:
            s.serialize(saved_path)

        if self._fs.need_upload_download():
            self._fs.delete(tmp_path)
            self._fs.upload(cache_path, tmp_path)
            local_fs.delete(cache_path)
        self._fs.mv(tmp_path, real_path)

        return real_path, max_no
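A hedged sketch of calling save_checkpoint, reusing the `model` object from the earlier sketch. With a LocalFS backend, need_upload_download() is expected to be false, so objects are serialized straight into a ".tmp" directory that is then renamed.

from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver

saver = CheckpointSaver(LocalFS())
# `model` is any SerializableBase, e.g. the PaddleModel sketched above
real_path, checkpoint_no = saver.save_checkpoint("./ckpt_root", [model])
# first call: real_path == "./ckpt_root/__paddle_checkpoint__.0", checkpoint_no == 0
# each subsequent call bumps the number: __paddle_checkpoint__.1, .2, ...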
    def load_checkpoint(self,
                        path,
                        slists,
                        trainer_id,
                        local_cache_path=".cache",
                        checkpoint_no=None,
                        ignore_empty=True):
        """
        Deserialize every object in slists from path.
        Return the actually loaded path, or None if no checkpoint exists.
        """
        if checkpoint_no is None:
            max_no = self._get_last_checkpoint_no(path)

            if not ignore_empty:
                assert max_no >= 0, "Can't find checkpoint"

            if max_no < 0:
                return None

            checkpoint_no = max_no
        else:
            assert isinstance(checkpoint_no, int)
            assert checkpoint_no >= 0

        local_fs = LocalFS()
        if self._fs.need_upload_download():
            cache_path = "{}/{}.{}.load_cache".format(
                local_cache_path, self._checkpoint_prefix, checkpoint_no)

            if trainer_id is not None:
                cache_path = "{}.{}".format(cache_path, trainer_id)

            if not local_fs.is_exist(local_cache_path):
                local_fs.mkdirs(local_cache_path)
            if local_fs.is_exist(cache_path):
                local_fs.delete(cache_path)

        real_path = "{}/{}.{}".format(path, self._checkpoint_prefix,
                                      checkpoint_no)
        load_path = real_path
        if self._fs.need_upload_download():
            self._fs.download(real_path, cache_path)
            load_path = cache_path

        for s in slists:
            s.deserialize(load_path)

        if self._fs.need_upload_download() and cache_path:
            local_fs.delete(cache_path)

        return real_path
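And the matching load side: with checkpoint_no=None the newest checkpoint is picked, and None is returned when nothing exists (unless ignore_empty=False, which asserts instead). A sketch continuing the example above:

# load the latest checkpoint back into the same objects
load_path = saver.load_checkpoint("./ckpt_root", [model], trainer_id=0)
if load_path is None:
    print("no __paddle_checkpoint__.N directory exists yet")

# or pin an explicit checkpoint number instead of taking the latest:
saver.load_checkpoint("./ckpt_root", [model], trainer_id=0, checkpoint_no=0)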

    def get_checkpoint_no(self, root_path):
        a = []
        dirs = self._fs.list_dirs(root_path)
        for d in dirs:
            g = d.split(".")
            if len(g) != 2:
                continue

            if g[0] != self._checkpoint_prefix:
                continue

            try:
                n = int(g[1])
                a.append(n)
            except ValueError:
                continue

        a.sort()
        return a

    def _get_last_checkpoint_no(self, root_path):
        """
        Only scan the first directory depth; return -1 when no
        checkpoint directory is found.
        """
        a = self.get_checkpoint_no(root_path)
        if len(a) > 0:
            return a[-1]

        return -1
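These two helpers only recognize immediate children named exactly __paddle_checkpoint__.&lt;int&gt;; anything else is skipped. Illustrative expectations, assuming the directory layout given in the comments:

# given children: __paddle_checkpoint__.0, __paddle_checkpoint__.3,
#                 __paddle_checkpoint__.exe (skipped: suffix is not an int)
saver.get_checkpoint_no("ckpt_root")        # -> [0, 3], sorted ascending
saver._get_last_checkpoint_no("ckpt_root")  # -> 3
# on an empty or non-matching directory: [] and -1 respectively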

    def clean_redundant_checkpoints(self, root_path, reserved=[]):
        max_no = self._get_last_checkpoint_no(root_path)
        if max_no < 0:
            return

        s = set(reserved)
        if len(s) == 0:
            s.add(max_no)

        dirs = self._fs.list_dirs(root_path)
        for d in dirs:
            g = d.split(".")
            if len(g) != 2:
                continue

            if g[0] != self._checkpoint_prefix:
                continue

            try:
                n = int(g[1])
                if n not in s:
                    path = "{}/{}.{}".format(root_path,
                                             self._checkpoint_prefix, n)
                    self._fs.delete(path)
            except Exception as e:
                print(e)
                continue
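Usage sketch: with no reserved list, only the newest checkpoint survives; pass reserved to keep a specific set of numbers instead.

saver.clean_redundant_checkpoints("ckpt_root")                   # keeps only the latest N
saver.clean_redundant_checkpoints("ckpt_root", reserved=[3, 5])  # keeps .3 and .5 only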
File diff suppressed because it is too large
@@ -0,0 +1,131 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys

from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel
from paddle.fluid.framework import program_guard
from paddle.fluid import unique_name

import numpy as np
from paddle.io import Dataset, BatchSampler, DataLoader

BATCH_NUM = 20
BATCH_SIZE = 16

#IMAGE_SIZE = 128
CLASS_NUM = 10

USE_GPU = False  # whether to run the model on GPU
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()

logger = None


def get_logger():
    global logger
    logger = acp._get_logger(20)
    return logger


def get_random_images_and_labels(image_shape, label_shape):
    image = np.random.random(size=image_shape).astype('float32')
    label = np.random.random(size=label_shape).astype('int64')
    return image, label


def sample_list_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM):
            sample_list = []
            for _ in range(BATCH_SIZE):
                image, label = get_random_images_and_labels([16, 16], [1])
                sample_list.append([image, label])

            yield sample_list

    return __reader__
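The reader built above yields BATCH_NUM sample lists per pass, each a list of BATCH_SIZE [image, label] pairs; a quick sanity sketch:

reader = sample_list_generator_creator()
batch = next(reader())
assert len(batch) == BATCH_SIZE
image, label = batch[0]
assert image.shape == (16, 16) and label.shape == (1,)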


class AutoCheckpointBase(unittest.TestCase):
    def _init_env(self,
                  exe,
                  main_prog,
                  startup_prog,
                  minimize=True,
                  iterable=True):
        def simple_net():
            image = fluid.data(
                name='image', shape=[-1, 16, 16], dtype='float32')
            label = fluid.data(name='label', shape=[-1, 1], dtype='int64')

            fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)
            cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp,
                                                                    label)
            loss = fluid.layers.reduce_mean(cross_entropy)
            sgd = fluid.optimizer.SGD(learning_rate=1e-3)
            if minimize:
                sgd.minimize(loss)
            return sgd, loss, image, label

        with program_guard(main_prog, startup_prog):
            sgd, loss, image, label = simple_net()

            if minimize:
                compiled = fluid.CompiledProgram(main_prog).with_data_parallel(
                    loss_name=loss.name)
            else:
                compiled = None
            loader = fluid.io.DataLoader.from_generator(
                feed_list=[image, label],
                capacity=64,
                use_double_buffer=True,
                iterable=iterable)

            loader.set_sample_list_generator(sample_list_generator_creator(),
                                             places[0])

            if minimize:
                exe.run(startup_prog)

        return compiled, loader, sgd, loss, image, label

    def _generate(self):
        main_prog = fluid.Program()
        startup_prog = fluid.Program()
        exe = fluid.Executor(places[0])

        return exe, main_prog, startup_prog

    def _reset_generator(self):
        unique_name.generator = fluid.unique_name.UniqueNameGenerator()
        acp.generator = fluid.unique_name.UniqueNameGenerator()
        acp.g_acp_type = None
        acp.g_checker = acp.AutoCheckpointChecker()
        acp.g_program_attr = {}

    def _clear_envs(self):
        os.environ.pop("PADDLE_RUNNING_ENV", None)

    def _readd_envs(self):
        os.environ["PADDLE_RUNNING_ENV"] = "PADDLE_EDL_AUTO_CHECKPOINT"

File diff suppressed because it is too large
@@ -0,0 +1,77 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys

from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel
from paddle.fluid.framework import program_guard
from paddle.fluid import unique_name

import numpy as np
from paddle.io import Dataset, BatchSampler, DataLoader

from paddle.fluid.tests.unittests.auto_checkpoint_utils import AutoCheckpointBase, get_logger
from paddle.fluid.tests.unittests.test_auto_checkpoint import AutoCheckPointACLBase

logger = get_logger()


class AutoCheckpointTest2(AutoCheckPointACLBase):
    def setUp(self):
        get_logger()
        logger.info("enter tests")

        self._old_environ = dict(os.environ)
        proc_env = {
            "PADDLE_RUNNING_ENV": "PADDLE_EDL_AUTO_CHECKPOINT",
            "PADDLE_TRAINER_ID": "0",
            "PADDLE_RUNNING_PLATFORM": "PADDLE_CLOUD",
            "PADDLE_JOB_ID": "test_job_auto_2",
            "PADDLE_EDL_HDFS_HOME": "/usr/local/hadoop-2.7.7",
            "PADDLE_EDL_HDFS_NAME": "",
            "PADDLE_EDL_HDFS_UGI": "",
            "PADDLE_EDL_HDFS_CHECKPOINT_PATH": "auto_checkpoint_2",
            "PADDLE_EDL_ONLY_FOR_CE_TEST": "1",
            "PADDLE_EDL_FS_CACHE": ".auto_checkpoint_test_2",
            "PADDLE_EDL_SAVE_CHECKPOINT_INTER": "0"
        }
        os.environ.update(proc_env)
    def test_corner_epoch_no(self):
        logger.info("begin test_corner_epoch_no")
        checker = acp._get_checker()
        fs = HDFSClient(checker.hdfs_home, None)

        for i in range(3):
            fs.delete(checker.hdfs_checkpoint_path)
            self._reset_generator()
            self._run_save_0(break_epoch_no=i)
            self._reset_generator()
            self._run_load_0(break_epoch_no=i)

        fs.delete(checker.hdfs_checkpoint_path)
        logger.info("end test_corner_epoch_no")


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,57 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
from paddle.fluid.incubate.checkpoint.auto_checkpoint import ExeTrainStatus
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
import os
import sys

from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient


class CheckpointerSaverTest(unittest.TestCase):
    def test(self):
        fs = HDFSClient("/usr/local/hadoop-2.7.7", None)
        dir_path = "./checkpointsaver_test"
        fs.delete(dir_path)

        s = CheckpointSaver(fs)

        # directories that do not match the __paddle_checkpoint__.<int> pattern
        fs.mkdirs("{}/exe.exe".format(dir_path))
        fs.mkdirs("{}/exe.1".format(dir_path))
        fs.mkdirs("{}/exe".format(dir_path))

        a = s.get_checkpoint_no(dir_path)
        self.assertEqual(len(a), 0)

        fs.mkdirs("{}/__paddle_checkpoint__.0".format(dir_path))
        fs.mkdirs("{}/__paddle_checkpoint__.exe".format(dir_path))

        a = s.get_checkpoint_no(dir_path)
        self.assertEqual(len(a), 1)

        s.clean_redundant_checkpoints(dir_path)
        s.clean_redundant_checkpoints(dir_path)

        fs.delete(dir_path)


if __name__ == '__main__':
    unittest.main()