|
|
|
@ -17,13 +17,12 @@ from __future__ import print_function
|
|
|
|
|
import os
|
|
|
|
|
import collections
|
|
|
|
|
from .. import core
|
|
|
|
|
from ..framework import Variable, Parameter, default_main_program
|
|
|
|
|
from .layers import Layer
|
|
|
|
|
from ..framework import Variable, default_main_program
|
|
|
|
|
|
|
|
|
|
__all__ = ['save_persistables', 'load_persistables']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_persistables(obj, dirname, filename=None):
|
|
|
|
|
def save_persistables(vardict, dirname, filename=None):
|
|
|
|
|
"""
|
|
|
|
|
This function filters out all variables in layer.parameters from the
|
|
|
|
|
give `layer` and then trys to load these variables from the folder
|
|
|
|
@ -35,7 +34,7 @@ def save_persistables(obj, dirname, filename=None):
|
|
|
|
|
the file name.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
var_list(dict of Parameters|Layer): The parameters will
|
|
|
|
|
vardict(dict of Parameters): The parameters will
|
|
|
|
|
be saved. If it is None, nothing
|
|
|
|
|
will be deal.
|
|
|
|
|
dirname(str): The directory path.
|
|
|
|
@ -69,17 +68,14 @@ def save_persistables(obj, dirname, filename=None):
|
|
|
|
|
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
|
|
|
|
|
init_cell)
|
|
|
|
|
param_path = "./my_paddle_model"
|
|
|
|
|
fluid.imperative.checkpoint.save_persistables(ptb_model.parameters(), dirname=param_path,
|
|
|
|
|
fluid.imperative.checkpoint.save_persistables(ptb_model.state_dict(), dirname=param_path,
|
|
|
|
|
layer=ptb_model)
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(obj, collections.OrderedDict):
|
|
|
|
|
_save_var_to_file(obj, dirname, filename)
|
|
|
|
|
elif isinstance(obj, Layer):
|
|
|
|
|
_save_var_to_file(
|
|
|
|
|
obj.state_dict(include_sublayers=True), dirname, filename)
|
|
|
|
|
if isinstance(vardict, collections.OrderedDict):
|
|
|
|
|
_save_var_to_file(vardict, dirname, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persistables(obj, dirname, filename=None):
|
|
|
|
|
def load_persistables(vardict, dirname, filename=None):
|
|
|
|
|
"""
|
|
|
|
|
This function trys to load persistable variables from the folder
|
|
|
|
|
`dirname` or the file `filename`.
|
|
|
|
@ -90,7 +86,7 @@ def load_persistables(obj, dirname, filename=None):
|
|
|
|
|
the file name.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
obj(dict of Parameters|Layer): The parameters will be loaded.
|
|
|
|
|
vardict(dict of Parameters): The parameters will be loaded.
|
|
|
|
|
dirname(str): The directory path.
|
|
|
|
|
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
|
|
|
|
|
saved in differnet files, set it to None.
|
|
|
|
@ -111,16 +107,13 @@ def load_persistables(obj, dirname, filename=None):
|
|
|
|
|
my_layer = layer(fluid.imperative.Layer)
|
|
|
|
|
param_path = "./my_paddle_model"
|
|
|
|
|
filename = "model.file"
|
|
|
|
|
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer, var_list, param_path,
|
|
|
|
|
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.state_dict(), param_path,
|
|
|
|
|
filename=filename)
|
|
|
|
|
param_1 = param_dict['PtbModel_0.w_1']
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(obj, collections.OrderedDict):
|
|
|
|
|
return _load_var_from_file(obj, dirname, filename)
|
|
|
|
|
elif isinstance(obj, Layer):
|
|
|
|
|
return _load_var_from_file(
|
|
|
|
|
obj.state_dict(include_sublayers=True), dirname, filename)
|
|
|
|
|
if isinstance(vardict, collections.OrderedDict):
|
|
|
|
|
return _load_var_from_file(vardict, dirname, filename)
|
|
|
|
|
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|