|
|
|
@ -20,6 +20,7 @@ import warnings
|
|
|
|
|
import time
|
|
|
|
|
import shutil
|
|
|
|
|
import six
|
|
|
|
|
import logging
|
|
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import layers
|
|
|
|
@ -29,12 +30,17 @@ from paddle.fluid.framework import Program, Parameter, default_main_program, def
|
|
|
|
|
from . import reader
|
|
|
|
|
from .reader import *
|
|
|
|
|
from . import core
|
|
|
|
|
from .. import compat as cpt
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
|
|
|
|
|
'load_persistables', 'save_inference_model', 'load_inference_model'
|
|
|
|
|
] + reader.__all__
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
_logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_parameter(var):
|
|
|
|
|
"""
|
|
|
|
@ -1181,3 +1187,80 @@ def get_parameter_value_by_name(name, executor, program=None):
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
var = program.global_block().var(name)
|
|
|
|
|
return get_parameter_value(var, executor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_persistable_nodes(executor, dirname, graph):
|
|
|
|
|
"""
|
|
|
|
|
Save persistable nodes to the given directory by the executor.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
executor(Executor): The executor to run for saving node values.
|
|
|
|
|
dirname(str): The directory path.
|
|
|
|
|
graph(IrGraph): All the required persistable nodes in the graph will be saved.
|
|
|
|
|
"""
|
|
|
|
|
persistable_node_names = set()
|
|
|
|
|
persistable_nodes = []
|
|
|
|
|
all_persistable_nodes = graph.all_persistable_nodes()
|
|
|
|
|
for node in all_persistable_nodes:
|
|
|
|
|
name = cpt.to_text(node.name())
|
|
|
|
|
if name not in persistable_node_names:
|
|
|
|
|
persistable_node_names.add(name)
|
|
|
|
|
persistable_nodes.append(node)
|
|
|
|
|
program = Program()
|
|
|
|
|
var_list = []
|
|
|
|
|
for node in persistable_nodes:
|
|
|
|
|
var_desc = node.var()
|
|
|
|
|
if var_desc.type() == core.VarDesc.VarType.RAW or \
|
|
|
|
|
var_desc.type() == core.VarDesc.VarType.READER:
|
|
|
|
|
continue
|
|
|
|
|
var = program.global_block().create_var(
|
|
|
|
|
name=var_desc.name(),
|
|
|
|
|
shape=var_desc.shape(),
|
|
|
|
|
dtype=var_desc.dtype(),
|
|
|
|
|
type=var_desc.type(),
|
|
|
|
|
lod_level=var_desc.lod_level(),
|
|
|
|
|
persistable=var_desc.persistable())
|
|
|
|
|
var_list.append(var)
|
|
|
|
|
save_vars(executor=executor, dirname=dirname, vars=var_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_persistable_nodes(executor, dirname, graph):
|
|
|
|
|
"""
|
|
|
|
|
Load persistable node values from the given directory by the executor.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
executor(Executor): The executor to run for loading node values.
|
|
|
|
|
dirname(str): The directory path.
|
|
|
|
|
graph(IrGraph): All the required persistable nodes in the graph will be loaded.
|
|
|
|
|
"""
|
|
|
|
|
persistable_node_names = set()
|
|
|
|
|
persistable_nodes = []
|
|
|
|
|
all_persistable_nodes = graph.all_persistable_nodes()
|
|
|
|
|
for node in all_persistable_nodes:
|
|
|
|
|
name = cpt.to_text(node.name())
|
|
|
|
|
if name not in persistable_node_names:
|
|
|
|
|
persistable_node_names.add(name)
|
|
|
|
|
persistable_nodes.append(node)
|
|
|
|
|
program = Program()
|
|
|
|
|
var_list = []
|
|
|
|
|
|
|
|
|
|
def _exist(var):
|
|
|
|
|
return os.path.exists(os.path.join(dirname, var.name))
|
|
|
|
|
|
|
|
|
|
for node in persistable_nodes:
|
|
|
|
|
var_desc = node.var()
|
|
|
|
|
if var_desc.type() == core.VarDesc.VarType.RAW or \
|
|
|
|
|
var_desc.type() == core.VarDesc.VarType.READER:
|
|
|
|
|
continue
|
|
|
|
|
var = program.global_block().create_var(
|
|
|
|
|
name=var_desc.name(),
|
|
|
|
|
shape=var_desc.shape(),
|
|
|
|
|
dtype=var_desc.dtype(),
|
|
|
|
|
type=var_desc.type(),
|
|
|
|
|
lod_level=var_desc.lod_level(),
|
|
|
|
|
persistable=var_desc.persistable())
|
|
|
|
|
if _exist(var):
|
|
|
|
|
var_list.append(var)
|
|
|
|
|
else:
|
|
|
|
|
_logger.warn("Cannot find the var %s!!!" % (node.name()))
|
|
|
|
|
load_vars(executor=executor, dirname=dirname, vars=var_list)
|
|
|
|
|