parent
5bb04ea47d
commit
bed0ecf3d2
@ -0,0 +1,194 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import collections
|
||||||
|
from .. import core
|
||||||
|
from ..framework import Variable, Parameter, default_main_program
|
||||||
|
from .layers import Layer
|
||||||
|
|
||||||
|
__all__ = ['save_persistables', 'load_persistables']
|
||||||
|
|
||||||
|
|
||||||
|
def save_persistables(obj, dirname, filename=None):
|
||||||
|
"""
|
||||||
|
This function filters out all variables in layer.parameters from the
|
||||||
|
give `layer` and then trys to load these variables from the folder
|
||||||
|
`dirname` or the file `filename`.
|
||||||
|
|
||||||
|
Use the `dirname` to specify the folder where persistable variables were
|
||||||
|
saved. If variables were saved in separate files, set `filename` None;
|
||||||
|
if all variables were saved in a single file, use `filename` to specify
|
||||||
|
the file name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
var_list(dict of Parameters|Layer): The parameters will
|
||||||
|
be saved. If it is None, nothing
|
||||||
|
will be deal.
|
||||||
|
dirname(str): The directory path.
|
||||||
|
filename(str|None): The file which saved all variables. If variables were
|
||||||
|
saved in differnet files, set it to None.
|
||||||
|
Default: None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
ptb_model = PtbModel(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
num_layers=num_layers,
|
||||||
|
num_steps=num_steps,
|
||||||
|
init_scale=init_scale)
|
||||||
|
|
||||||
|
x_data = np.arange(12).reshape(4, 3).astype('int64')
|
||||||
|
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
|
||||||
|
x_data = x_data.reshape((-1, num_steps, 1))
|
||||||
|
y_data = y_data.reshape((-1, 1))
|
||||||
|
init_hidden_data = np.zeros(
|
||||||
|
(num_layers, batch_size, hidden_size), dtype='float32')
|
||||||
|
init_cell_data = np.zeros(
|
||||||
|
(num_layers, batch_size, hidden_size), dtype='float32')
|
||||||
|
x = to_variable(x_data)
|
||||||
|
y = to_variable(y_data)
|
||||||
|
init_hidden = to_variable(init_hidden_data)
|
||||||
|
init_cell = to_variable(init_cell_data)
|
||||||
|
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
|
||||||
|
init_cell)
|
||||||
|
param_path = "./my_paddle_model"
|
||||||
|
fluid.imperative.checkpoint.save_persistables(ptb_model.parameters(), dirname=param_path,
|
||||||
|
layer=ptb_model)
|
||||||
|
"""
|
||||||
|
if isinstance(obj, collections.OrderedDict):
|
||||||
|
_save_var_to_file(obj, dirname, filename)
|
||||||
|
elif isinstance(obj, Layer):
|
||||||
|
_save_var_to_file(
|
||||||
|
obj.state_dict(include_sublayers=True), dirname, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def load_persistables(obj, dirname, filename=None):
|
||||||
|
"""
|
||||||
|
This function trys to load persistable variables from the folder
|
||||||
|
`dirname` or the file `filename`.
|
||||||
|
|
||||||
|
Use the `dirname` to specify the folder where persistable variables were
|
||||||
|
saved. If variables were saved in separate files, set `filename` None;
|
||||||
|
if all variables were saved in a single file, use `filename` to specify
|
||||||
|
the file name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj(dict of Parameters|Layer): The parameters will be loaded.
|
||||||
|
dirname(str): The directory path.
|
||||||
|
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
|
||||||
|
saved in differnet files, set it to None.
|
||||||
|
Default: None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The parameter-dict resumed from file
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
my_layer = layer(fluid.imperative.Layer)
|
||||||
|
param_path = "./my_paddle_model"
|
||||||
|
|
||||||
|
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.parameters(), param_path)
|
||||||
|
param_1 = param_dict['PtbModel_0.w_1']
|
||||||
|
|
||||||
|
or:
|
||||||
|
my_layer = layer(fluid.imperative.Layer)
|
||||||
|
param_path = "./my_paddle_model"
|
||||||
|
filename = "model.file"
|
||||||
|
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer, var_list, param_path,
|
||||||
|
filename=filename)
|
||||||
|
param_1 = param_dict['PtbModel_0.w_1']
|
||||||
|
|
||||||
|
"""
|
||||||
|
if isinstance(obj, collections.OrderedDict):
|
||||||
|
return _load_var_from_file(obj, dirname, filename)
|
||||||
|
elif isinstance(obj, Layer):
|
||||||
|
return _load_var_from_file(
|
||||||
|
obj.state_dict(include_sublayers=True), dirname, filename)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _save_var_to_file(stat_dict, file_dir, file_name):
|
||||||
|
save_block = default_main_program().global_block()
|
||||||
|
save_var_map = {}
|
||||||
|
for each_var in stat_dict.items():
|
||||||
|
save_var_map[each_var.name] = each_var
|
||||||
|
if file_name is None:
|
||||||
|
save_block.append_op(
|
||||||
|
type='save',
|
||||||
|
inputs={'X': [each_var]},
|
||||||
|
outputs={},
|
||||||
|
attrs={'file_path': os.path.join(file_dir, each_var.name)})
|
||||||
|
|
||||||
|
if file_name is not None:
|
||||||
|
save_var_list = []
|
||||||
|
for name in sorted(save_var_map.keys()):
|
||||||
|
save_var_list.append(save_var_map[name])
|
||||||
|
|
||||||
|
save_block.append_op(
|
||||||
|
type='save_combine',
|
||||||
|
inputs={'X': save_var_list},
|
||||||
|
outputs={},
|
||||||
|
attrs={'file_path': os.path.join(file_dir, file_name)})
|
||||||
|
|
||||||
|
|
||||||
|
def _load_var_from_file(stat_dict, file_dir, file_name):
|
||||||
|
load_block = default_main_program().global_block()
|
||||||
|
load_var_map = {}
|
||||||
|
|
||||||
|
for each_var in stat_dict.items():
|
||||||
|
assert isinstance(each_var, Variable)
|
||||||
|
if each_var.type == core.VarDesc.VarType.RAW:
|
||||||
|
continue
|
||||||
|
new_var = _clone_var_in_block_(load_block, each_var)
|
||||||
|
if file_name is None:
|
||||||
|
load_block.append_op(
|
||||||
|
type='load',
|
||||||
|
inputs={},
|
||||||
|
outputs={'Out': [new_var]},
|
||||||
|
attrs={'file_path': os.path.join(file_dir, each_var.name)})
|
||||||
|
|
||||||
|
load_var_map[new_var.name] = new_var
|
||||||
|
|
||||||
|
if file_name is not None:
|
||||||
|
load_var_list = []
|
||||||
|
for name in sorted(load_var_map.keys()):
|
||||||
|
load_var_list.append(load_var_map[name])
|
||||||
|
|
||||||
|
load_block.append_op(
|
||||||
|
type='load_combine',
|
||||||
|
inputs={},
|
||||||
|
outputs={"Out": load_var_list},
|
||||||
|
attrs={'file_path': os.path.join(file_dir, file_name)})
|
||||||
|
for res_var in load_var_list:
|
||||||
|
load_var_map[res_var.name] = res_var
|
||||||
|
|
||||||
|
return load_var_map
|
||||||
|
|
||||||
|
|
||||||
|
def _clone_var_in_block_(block, var):
|
||||||
|
assert isinstance(var, Variable)
|
||||||
|
return block.create_var(
|
||||||
|
name=var.name,
|
||||||
|
shape=var.shape,
|
||||||
|
dtype=var.dtype,
|
||||||
|
type=var.type,
|
||||||
|
lod_level=var.lod_level,
|
||||||
|
persistable=True)
|
@ -0,0 +1,163 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid.optimizer import SGDOptimizer
|
||||||
|
from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
|
||||||
|
from paddle.fluid.imperative.base import to_variable
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleImgConvPool(fluid.imperative.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
name_scope,
|
||||||
|
num_channels,
|
||||||
|
num_filters,
|
||||||
|
filter_size,
|
||||||
|
pool_size,
|
||||||
|
pool_stride,
|
||||||
|
pool_padding=0,
|
||||||
|
pool_type='max',
|
||||||
|
global_pooling=False,
|
||||||
|
conv_stride=1,
|
||||||
|
conv_padding=0,
|
||||||
|
conv_dilation=1,
|
||||||
|
conv_groups=1,
|
||||||
|
act=None,
|
||||||
|
use_cudnn=False,
|
||||||
|
param_attr=None,
|
||||||
|
bias_attr=None):
|
||||||
|
super(SimpleImgConvPool, self).__init__(name_scope)
|
||||||
|
|
||||||
|
self._conv2d = Conv2D(
|
||||||
|
self.full_name(),
|
||||||
|
num_channels=num_channels,
|
||||||
|
num_filters=num_filters,
|
||||||
|
filter_size=filter_size,
|
||||||
|
stride=conv_stride,
|
||||||
|
padding=conv_padding,
|
||||||
|
dilation=conv_dilation,
|
||||||
|
groups=conv_groups,
|
||||||
|
param_attr=None,
|
||||||
|
bias_attr=None,
|
||||||
|
use_cudnn=use_cudnn)
|
||||||
|
|
||||||
|
self._pool2d = Pool2D(
|
||||||
|
self.full_name(),
|
||||||
|
pool_size=pool_size,
|
||||||
|
pool_type=pool_type,
|
||||||
|
pool_stride=pool_stride,
|
||||||
|
pool_padding=pool_padding,
|
||||||
|
global_pooling=global_pooling,
|
||||||
|
use_cudnn=use_cudnn)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
x = self._conv2d(inputs)
|
||||||
|
x = self._pool2d(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MNIST(fluid.imperative.Layer):
|
||||||
|
def __init__(self, name_scope):
|
||||||
|
super(MNIST, self).__init__(name_scope)
|
||||||
|
|
||||||
|
self._simple_img_conv_pool_1 = SimpleImgConvPool(
|
||||||
|
self.full_name(), 1, 20, 5, 2, 2, act="relu")
|
||||||
|
|
||||||
|
self._simple_img_conv_pool_2 = SimpleImgConvPool(
|
||||||
|
self.full_name(), 20, 50, 5, 2, 2, act="relu")
|
||||||
|
|
||||||
|
pool_2_shape = 50 * 4 * 4
|
||||||
|
SIZE = 10
|
||||||
|
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
|
||||||
|
self._fc = FC(self.full_name(),
|
||||||
|
10,
|
||||||
|
param_attr=fluid.param_attr.ParamAttr(
|
||||||
|
initializer=fluid.initializer.NormalInitializer(
|
||||||
|
loc=0.0, scale=scale)),
|
||||||
|
act="softmax")
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
x = self._simple_img_conv_pool_1(inputs)
|
||||||
|
x = self._simple_img_conv_pool_2(x)
|
||||||
|
x = self._fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TestImperativeCheckpoint(unittest.TestCase):
|
||||||
|
def save_load_persistables(self):
|
||||||
|
seed = 90
|
||||||
|
epoch_num = 1
|
||||||
|
|
||||||
|
with fluid.imperative.guard():
|
||||||
|
fluid.default_startup_program().random_seed = seed
|
||||||
|
fluid.default_main_program().random_seed = seed
|
||||||
|
|
||||||
|
mnist = MNIST("mnist")
|
||||||
|
sgd = SGDOptimizer(learning_rate=1e-3)
|
||||||
|
train_reader = paddle.batch(
|
||||||
|
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
|
||||||
|
|
||||||
|
dy_param_init_value = {}
|
||||||
|
|
||||||
|
step = 0
|
||||||
|
for epoch in range(epoch_num):
|
||||||
|
for batch_id, data in enumerate(train_reader()):
|
||||||
|
dy_x_data = np.array(
|
||||||
|
[x[0].reshape(1, 28, 28)
|
||||||
|
for x in data]).astype('float32')
|
||||||
|
y_data = np.array(
|
||||||
|
[x[1] for x in data]).astype('int64').reshape(128, 1)
|
||||||
|
|
||||||
|
img = to_variable(dy_x_data)
|
||||||
|
label = to_variable(y_data)
|
||||||
|
label._stop_gradient = True
|
||||||
|
|
||||||
|
cost = mnist(img)
|
||||||
|
loss = fluid.layers.cross_entropy(cost, label)
|
||||||
|
avg_loss = fluid.layers.mean(loss)
|
||||||
|
|
||||||
|
dy_out = avg_loss._numpy()
|
||||||
|
|
||||||
|
avg_loss._backward()
|
||||||
|
sgd.minimize(avg_loss)
|
||||||
|
fluid.imperative.save_persistables(mnist, "save_dir")
|
||||||
|
mnist.clear_gradients()
|
||||||
|
|
||||||
|
for param in mnist.parameters():
|
||||||
|
dy_param_init_value[param.name] = param._numpy()
|
||||||
|
|
||||||
|
mnist.load_dict(
|
||||||
|
fluid.imperative.load_persistables(mnist, "save_dir"))
|
||||||
|
|
||||||
|
restore = mnist.parameters()
|
||||||
|
|
||||||
|
self.assertEqual(len(dy_param_init_value), len(restore))
|
||||||
|
for value in restore:
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(value, dy_param_init_value[value.name]))
|
||||||
|
self.assertTrue(np.isfinite(value.all()))
|
||||||
|
self.assertFalse(np.isnan(value.any()))
|
||||||
|
|
||||||
|
step += 1
|
||||||
|
|
||||||
|
if step > 20:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue