You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/python/paddle/fluid/incubate/checkpoint/checkpoint_saver.py

223 lines
6.6 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...compiler import CompiledProgram
class SerializableBase(object):
    """Interface for objects that can be written to and restored from a path.

    Both methods raise ``NotImplementedError`` here; concrete subclasses are
    expected to override them.
    """

    def serialize(self, path):
        """Persist this object under ``path`` (must be overridden)."""
        raise NotImplementedError()

    def deserialize(self, path):
        """Restore this object from ``path`` (must be overridden)."""
        raise NotImplementedError()
class PaddleModel(SerializableBase):
    """Serializable wrapper around an executor and a program.

    Saving and loading are delegated to ``save_persistables`` and
    ``load_persistables`` with a fixed file name.
    """

    def __init__(self, exe, program):
        """
        Args:
            exe: executor handed to the save/load helpers.
            program: the program to save/load. When it is a
                ``CompiledProgram`` its ``_program`` attribute is used
                for the actual save/load calls.
        """
        self._exe = exe
        # Keep the object exactly as the caller passed it.
        self._origin_program = program
        self._program = (program._program
                         if isinstance(program, CompiledProgram) else program)
        self._file_name = "_paddle_fleet_param__"

    def _io_kwargs(self, path):
        """Build the keyword arguments shared by the save and load calls."""
        return dict(
            executor=self._exe,
            dirname=path,
            main_program=self._program,
            filename=self._file_name)

    def serialize(self, path):
        """Save into directory ``path`` via ``save_persistables``."""
        from ...io import save_persistables

        save_persistables(**self._io_kwargs(path))

    def deserialize(self, path):
        """Load from directory ``path`` via ``load_persistables``."""
        from ...io import load_persistables

        load_persistables(**self._io_kwargs(path))
class CheckpointSaver(object):
    """Save, load and prune numbered checkpoint directories.

    Checkpoints are stored directly under a root directory and are named
    ``__paddle_checkpoint__.<no>`` where ``<no>`` is an integer. A new
    checkpoint is first written to ``<name>.tmp`` and then moved to its
    final name.
    """

    def __init__(self, fs):
        """
        Args:
            fs: file-system client. This class calls its ``is_exist``,
                ``is_dir``, ``mkdirs``, ``list_dirs``, ``delete``, ``mv``,
                ``upload``, ``download`` and ``need_upload_download``
                methods.
        """
        self._fs = fs
        self._checkpoint_prefix = "__paddle_checkpoint__"

    def save_checkpoint(self,
                        path,
                        slists,
                        trainer_id=None,
                        local_cache_path=".cache"):
        """
        Serialize objects in slists to path
        Return really saved path and checkpoint_no

        Args:
            path: root directory of the checkpoints; created when missing.
            slists: iterable of objects exposing ``serialize(path)``.
            trainer_id: optional suffix appended to the local cache
                directory name.
            local_cache_path: local directory used for staging when
                ``fs.need_upload_download()`` is true.

        Returns:
            tuple ``(real_path, checkpoint_no)``.

        Raises:
            AssertionError: if ``path`` (or the local cache path) exists
                but is not a directory.
        """
        if not self._fs.is_exist(path):
            self._fs.mkdirs(path)
        else:
            assert self._fs.is_dir(path), "path:{} must be a directory".format(
                path)

        # Any negative "last number" means "no checkpoint yet" -> start at 0.
        max_no = max(self._get_last_checkpoint_no(path), -1) + 1

        real_path = "{}/{}.{}".format(path, self._checkpoint_prefix, max_no)
        tmp_path = "{}.tmp".format(real_path)
        saved_path = tmp_path

        from paddle.distributed.fleet.utils.fs import LocalFS
        local_fs = LocalFS()

        need_cache = self._fs.need_upload_download()
        cache_path = None
        if need_cache:
            cache_path = "{}/{}.{}.saved_cache".format(
                local_cache_path, self._checkpoint_prefix, max_no)
            if trainer_id is not None:
                cache_path = "{}.{}".format(cache_path, trainer_id)

            if not local_fs.is_exist(cache_path):
                local_fs.mkdirs(cache_path)
            else:
                assert local_fs.is_dir(cache_path), \
                    "cache path:{} must be a directory".format(cache_path)

            saved_path = cache_path

        try:
            for s in slists:
                s.serialize(saved_path)

            if need_cache:
                # Drop any stale temporary directory before uploading.
                self._fs.delete(tmp_path)
                self._fs.upload(cache_path, tmp_path)
        finally:
            # Remove the local staging directory even when serialization
            # or upload fails, so it is not leaked.
            if need_cache:
                local_fs.delete(cache_path)

        # Publish the checkpoint under its final name only once complete.
        self._fs.mv(tmp_path, real_path)

        return real_path, max_no

    def load_checkpoint(self,
                        path,
                        slists,
                        trainer_id,
                        local_cache_path=".cache",
                        checkpoint_no=None,
                        ignore_empty=True):
        """
        Deserialize objects in slists from path
        Return really load path

        Args:
            path: root directory of the checkpoints.
            slists: iterable of objects exposing ``deserialize(path)``.
            trainer_id: when not None, appended to the local cache
                directory name.
            local_cache_path: local directory used for staging when
                ``fs.need_upload_download()`` is true.
            checkpoint_no: number of the checkpoint to load; when None the
                largest existing number is used.
            ignore_empty: when ``checkpoint_no`` is None and no checkpoint
                exists, return None if true, otherwise fail an assertion.

        Returns:
            The checkpoint path that was loaded, or None when nothing was
            found and ``ignore_empty`` is true.

        Raises:
            AssertionError: if ``checkpoint_no`` is given but is not a
                non-negative int, or no checkpoint exists and
                ``ignore_empty`` is false.
        """
        if checkpoint_no is None:
            max_no = self._get_last_checkpoint_no(path)

            if not ignore_empty:
                assert max_no >= 0, "Can't find checkpoint"

            if max_no < 0:
                return None

            checkpoint_no = max_no
        else:
            assert isinstance(checkpoint_no, int)
            assert checkpoint_no >= 0

        from paddle.distributed.fleet.utils.fs import LocalFS
        local_fs = LocalFS()

        need_cache = self._fs.need_upload_download()
        cache_path = None
        if need_cache:
            cache_path = "{}/{}.{}.load_cache".format(
                local_cache_path, self._checkpoint_prefix, checkpoint_no)

            if trainer_id is not None:
                cache_path = "{}.{}".format(cache_path, trainer_id)

            if not local_fs.is_exist(local_cache_path):
                local_fs.mkdirs(local_cache_path)
            # Start from a clean staging directory.
            if local_fs.is_exist(cache_path):
                local_fs.delete(cache_path)

        real_path = "{}/{}.{}".format(path, self._checkpoint_prefix,
                                      checkpoint_no)
        load_path = real_path
        try:
            if need_cache:
                self._fs.download(real_path, cache_path)
                load_path = cache_path

            for s in slists:
                s.deserialize(load_path)
        finally:
            # Remove the downloaded copy even when download or
            # deserialization fails, so it is not leaked.
            if need_cache and local_fs.is_exist(cache_path):
                local_fs.delete(cache_path)

        return real_path

    def _parse_checkpoint_no(self, dir_name):
        """Return the number encoded in ``dir_name``, or None.

        Only names of the exact form ``<prefix>.<int>`` match; names with
        extra dots (for example the ``.tmp`` directories) do not.
        """
        parts = dir_name.split(".")
        if len(parts) != 2 or parts[0] != self._checkpoint_prefix:
            return None
        try:
            return int(parts[1])
        except ValueError:
            return None

    def get_checkpoint_no(self, root_path):
        """Return the sorted list of checkpoint numbers under ``root_path``.

        Entries returned by ``fs.list_dirs`` that do not look like
        ``<prefix>.<int>`` are ignored.
        """
        numbers = []
        for d in self._fs.list_dirs(root_path):
            n = self._parse_checkpoint_no(d)
            if n is not None:
                numbers.append(n)

        numbers.sort()
        return numbers

    def _get_last_checkpoint_no(self, root_path):
        """
        only get the first depth

        Returns the largest checkpoint number under ``root_path``, or -1
        when there is none.
        """
        numbers = self.get_checkpoint_no(root_path)
        return numbers[-1] if numbers else -1

    def clean_redundant_checkpoints(self, root_path, reserved=None):
        """Delete every checkpoint under ``root_path`` not in ``reserved``.

        Args:
            root_path: root directory of the checkpoints.
            reserved: iterable of checkpoint numbers to keep. When it is
                None or empty, only the checkpoint with the largest number
                is kept.

        A failure to delete one checkpoint is printed and does not stop the
        remaining deletions. Does nothing when no checkpoint exists.
        """
        max_no = self._get_last_checkpoint_no(root_path)
        if max_no < 0:
            return

        keep = set(reserved) if reserved is not None else set()
        if not keep:
            keep.add(max_no)

        for d in self._fs.list_dirs(root_path):
            n = self._parse_checkpoint_no(d)
            if n is None or n in keep:
                continue

            path = "{}/{}.{}".format(root_path, self._checkpoint_prefix, n)
            try:
                self._fs.delete(path)
            except Exception as e:
                # Best effort: report and go on with the other checkpoints.
                print(e)