|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import unittest
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCheckpoint(unittest.TestCase):
|
|
|
|
@ -35,8 +36,8 @@ class TestCheckpoint(unittest.TestCase):
|
|
|
|
|
trainer_args = ["epoch_id", "step_id"]
|
|
|
|
|
epoch_id, step_id = fluid.io.load_trainer_args(
|
|
|
|
|
self.dirname, serial, self.trainer_id, trainer_args)
|
|
|
|
|
self.assertEqual(self.step_id, step_id)
|
|
|
|
|
self.assertEqual(self.epoch_id, epoch_id)
|
|
|
|
|
self.assertEqual(self.step_id, int(step_id))
|
|
|
|
|
self.assertEqual(self.epoch_id, int(epoch_id))
|
|
|
|
|
|
|
|
|
|
program = fluid.Program()
|
|
|
|
|
with fluid.program_guard(program):
|
|
|
|
@ -44,6 +45,7 @@ class TestCheckpoint(unittest.TestCase):
|
|
|
|
|
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
|
|
|
|
|
|
|
|
|
|
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
|
|
|
|
|
self.assertFalse(os.path.isdir(self.dirname))
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(self):
|
|
|
|
|
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
|
|
|
|
|