@ -24,7 +24,7 @@ import contextlib
from functools import reduce
import numpy as np
import math
import paddle
from paddle . fluid import layers
from paddle . fluid . executor import Executor , global_scope
@ -1710,6 +1710,52 @@ def _load_persistable_nodes(executor, dirname, graph):
load_vars ( executor = executor , dirname = dirname , vars = var_list )
def _unpack_saved_dict ( saved_obj ) :
temp_saved_obj = { }
unpack_infor = { }
for key , value in saved_obj . items ( ) :
if isinstance ( value , np . ndarray ) :
MAX_NUMBER_OF_ELEMENT = 2 * * 22
num_element = np . prod ( value . shape )
if num_element > MAX_NUMBER_OF_ELEMENT :
unpack_infor [ key ] = { }
unpack_infor [ key ] [ " OriginShape " ] = value . shape
unpack_infor [ key ] [ " slices " ] = [ ]
value = value . flatten ( )
for i in range (
int (
math . ceil ( num_element * 1.0 /
MAX_NUMBER_OF_ELEMENT ) ) ) :
part_name = key + " @@. " + str ( i )
unpack_infor [ key ] [ " slices " ] . append ( part_name )
temp_saved_obj [ part_name ] = value [
i * MAX_NUMBER_OF_ELEMENT : MAX_NUMBER_OF_ELEMENT * ( i + 1
) ]
if unpack_infor :
for key , value in unpack_infor . items ( ) :
if key in saved_obj :
saved_obj . pop ( key )
for part in value [ ' slices ' ] :
saved_obj [ part ] = temp_saved_obj [ part ]
saved_obj [ ' UnpackBigParamInfor@@ ' ] = unpack_infor
return saved_obj
def _pack_loaded_dict ( load_obj ) :
unpack_info = ' UnpackBigParamInfor@@ '
if unpack_info in load_obj :
removes = [ ]
for key , value in load_obj [ unpack_info ] . items ( ) :
slices = [ load_obj [ part ] for part in value [ " slices " ] ]
load_obj [ key ] = np . concatenate ( slices ) . reshape ( value [ " OriginShape " ] )
removes + = value [ " slices " ]
for key in removes :
load_obj . pop ( key )
load_obj . pop ( unpack_info )
return load_obj
@static_only
def save ( program , model_path ) :
"""
@ -1762,6 +1808,7 @@ def save(program, model_path):
parameter_list = list ( filter ( is_parameter , program . list_vars ( ) ) )
param_dict = { p . name : get_tensor ( p ) for p in parameter_list }
param_dict = _unpack_saved_dict ( param_dict )
with open ( model_path + " .pdparams " , ' wb ' ) as f :
pickle . dump ( param_dict , f , protocol = 2 )
@ -1935,6 +1982,7 @@ def load(program, model_path, executor=None, var_list=None):
with open ( parameter_file_name , ' rb ' ) as f :
load_dict = pickle . load ( f ) if six . PY2 else pickle . load (
f , encoding = ' latin1 ' )
load_dict = _pack_loaded_dict ( load_dict )
for v in parameter_list :
assert v . name in load_dict , \
" Can not find [ {} ] in model file [ {} ] " . format (