@ -24,7 +24,7 @@ import contextlib
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  functools  import  reduce 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					import  numpy  as  np 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					import  math 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					import  paddle 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  paddle . fluid  import  layers 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  paddle . fluid . executor  import  Executor ,  global_scope 
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -1710,6 +1710,52 @@ def _load_persistable_nodes(executor, dirname, graph):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    load_vars ( executor = executor ,  dirname = dirname ,  vars = var_list ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					def  _unpack_saved_dict ( saved_obj ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    temp_saved_obj  =  { } 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    unpack_infor  =  { } 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    for  key ,  value  in  saved_obj . items ( ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  isinstance ( value ,  np . ndarray ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            MAX_NUMBER_OF_ELEMENT  =  2 * * 22 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            num_element  =  np . prod ( value . shape ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            if  num_element  >  MAX_NUMBER_OF_ELEMENT : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                unpack_infor [ key ]  =  { } 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                unpack_infor [ key ] [ " OriginShape " ]  =  value . shape 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                unpack_infor [ key ] [ " slices " ]  =  [ ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                value  =  value . flatten ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                for  i  in  range ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                        int ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                            math . ceil ( num_element  *  1.0  / 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                      MAX_NUMBER_OF_ELEMENT ) ) ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                    part_name  =  key  +  " @@. "  +  str ( i ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                    unpack_infor [ key ] [ " slices " ] . append ( part_name ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                    temp_saved_obj [ part_name ]  =  value [ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                        i  *  MAX_NUMBER_OF_ELEMENT : MAX_NUMBER_OF_ELEMENT  *  ( i  +  1 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                                           ) ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    if  unpack_infor : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        for  key ,  value  in  unpack_infor . items ( ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            if  key  in  saved_obj : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                saved_obj . pop ( key ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                for  part  in  value [ ' slices ' ] : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                    saved_obj [ part ]  =  temp_saved_obj [ part ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        saved_obj [ ' UnpackBigParamInfor@@ ' ]  =  unpack_infor 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    return  saved_obj 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					def  _pack_loaded_dict ( load_obj ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    unpack_info  =  ' UnpackBigParamInfor@@ ' 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    if  unpack_info  in  load_obj : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        removes  =  [ ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        for  key ,  value  in  load_obj [ unpack_info ] . items ( ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            slices  =  [ load_obj [ part ]  for  part  in  value [ " slices " ] ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            load_obj [ key ]  =  np . concatenate ( slices ) . reshape ( value [ " OriginShape " ] ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            removes  + =  value [ " slices " ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        for  key  in  removes : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            load_obj . pop ( key ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        load_obj . pop ( unpack_info ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    return  load_obj 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					@static_only 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					def  save ( program ,  model_path ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    """ 
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -1762,6 +1808,7 @@ def save(program, model_path):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    parameter_list  =  list ( filter ( is_parameter ,  program . list_vars ( ) ) ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    param_dict  =  { p . name :  get_tensor ( p )  for  p  in  parameter_list } 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    param_dict  =  _unpack_saved_dict ( param_dict ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    with  open ( model_path  +  " .pdparams " ,  ' wb ' )  as  f : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        pickle . dump ( param_dict ,  f ,  protocol = 2 ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -1935,6 +1982,7 @@ def load(program, model_path, executor=None, var_list=None):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    with  open ( parameter_file_name ,  ' rb ' )  as  f : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        load_dict  =  pickle . load ( f )  if  six . PY2  else  pickle . load ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            f ,  encoding = ' latin1 ' ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        load_dict  =  _pack_loaded_dict ( load_dict ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    for  v  in  parameter_list : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assert  v . name  in  load_dict ,  \
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            " Can not find [ {} ] in model file [ {} ] " . format (