@ -139,7 +139,7 @@ class Context(object):
"""
"""
Load the context from file .
Load the context from file .
"""
"""
with open ( file_name ) as context_file :
with open ( file_name , ' rb ' ) as context_file :
if sys . version_info < ( 3 , 0 ) :
if sys . version_info < ( 3 , 0 ) :
data = pickle . load ( context_file )
data = pickle . load ( context_file )
else :
else :
@ -242,6 +242,9 @@ class Compressor(object):
eval_reader = None ,
eval_reader = None ,
eval_feed_list = None ,
eval_feed_list = None ,
eval_fetch_list = None ,
eval_fetch_list = None ,
eval_func = None ,
save_eval_model = True ,
prune_infer_model = None ,
teacher_programs = [ ] ,
teacher_programs = [ ] ,
checkpoint_path = None ,
checkpoint_path = None ,
train_optimizer = None ,
train_optimizer = None ,
@ -260,13 +263,28 @@ class Compressor(object):
The key is user - defined and human - readable name .
The key is user - defined and human - readable name .
The value is the name of Variable .
The value is the name of Variable .
eval_program ( Program ) : The program used for evaluation .
eval_program ( Program ) : The program used for evaluation .
eval_reader : The data reader used for evaluation .
eval_reader : The data reader used for evaluation . It can be None if eval_func is not None .
eval_feed_list ( dict ) : A dict to indicate the input variable of the evaluation program .
eval_feed_list ( dict ) : A dict to indicate the input variable of the evaluation program .
The key is user - defined and human - readable name .
The key is user - defined and human - readable name .
The value is the name of Variable .
The value is the name of Variable .
It can be None if eval_func is not None .
eval_fetch_list ( dict ) : A dict to indicate the output variable of the evaluation program .
eval_fetch_list ( dict ) : A dict to indicate the output variable of the evaluation program .
The key is user - defined and human - readable name .
The key is user - defined and human - readable name .
The value is the name of Variable .
The value is the name of Variable .
eval_func ( dict | function ) : Callback functions used to evaluate the compressed model .
The eval_func is a dict , the key is user - defined name and the value is
a callback function . And the score returned from callback functions
can be referenced in config file by the key of eval_func .
The args of callback function are compressed eval_program and scope which
store the compressed parameters .
Default : None .
save_eval_model ( bool ) : Whether to save eval model when saving checkpoints . Default : True .
prune_infer_model ( tuple | list ) : If prune_infer_model is not None , compressor will prune
eval program into inference program according to inputs and outputs
defined in prune_infer_model . prune_infer_model [ 0 ] is a list of input
variables ' names and prune_infer_model[1] is a list of output variables '
names . If prune_infer_model is None , it will not save inference model .
Default : None .
teacher_programs : The teacher graphs used in distillation strategies .
teacher_programs : The teacher graphs used in distillation strategies .
train_optimizer : The optimizer used to append backward ops and
train_optimizer : The optimizer used to append backward ops and
optimization ops into train_graph .
optimization ops into train_graph .
@ -294,6 +312,10 @@ class Compressor(object):
eval_program , in_nodes = eval_feed_list , out_nodes = eval_fetch_list )
eval_program , in_nodes = eval_feed_list , out_nodes = eval_fetch_list )
self . train_reader = train_reader
self . train_reader = train_reader
self . eval_reader = eval_reader
self . eval_reader = eval_reader
self . eval_func = eval_func
self . save_eval_model = save_eval_model
self . prune_infer_model = prune_infer_model
self . teacher_graphs = [ ]
self . teacher_graphs = [ ]
for teacher in teacher_programs :
for teacher in teacher_programs :
self . teacher_graphs . append ( GraphWrapper ( teacher ) )
self . teacher_graphs . append ( GraphWrapper ( teacher ) )
@ -393,6 +415,9 @@ class Compressor(object):
strategies = pickle . load (
strategies = pickle . load (
strategy_file , encoding = ' bytes ' )
strategy_file , encoding = ' bytes ' )
for s , s1 in zip ( self . strategies , strategies ) :
s1 . __dict__ . update ( s . __dict__ )
for strategy in strategies :
for strategy in strategies :
strategy . restore_from_checkpoint ( context )
strategy . restore_from_checkpoint ( context )
@ -401,10 +426,6 @@ class Compressor(object):
with scope_guard ( context . scope ) :
with scope_guard ( context . scope ) :
context . optimize_graph . load_persistables ( model_path ,
context . optimize_graph . load_persistables ( model_path ,
exe )
exe )
context . optimize_graph . update_param_shape ( context . scope )
context . optimize_graph . update_groups_of_conv ( )
context . eval_graph . update_param_shape ( context . scope )
context . eval_graph . update_groups_of_conv ( )
_logger . info ( " Loaded params from: {} " . format ( model_path ) )
_logger . info ( " Loaded params from: {} " . format ( model_path ) )
return context , strategies
return context , strategies
@ -416,6 +437,7 @@ class Compressor(object):
checkpoint_path = os . path . join ( self . checkpoint_path ,
checkpoint_path = os . path . join ( self . checkpoint_path ,
str ( context . epoch_id ) )
str ( context . epoch_id ) )
model_path = os . path . join ( checkpoint_path , ' model ' )
model_path = os . path . join ( checkpoint_path , ' model ' )
eval_model_path = os . path . join ( checkpoint_path , ' eval_model ' )
context_path = os . path . join ( checkpoint_path , ' context ' )
context_path = os . path . join ( checkpoint_path , ' context ' )
strategy_path = os . path . join ( checkpoint_path , ' strategies ' )
strategy_path = os . path . join ( checkpoint_path , ' strategies ' )
if not os . path . isdir ( model_path ) :
if not os . path . isdir ( model_path ) :
@ -423,6 +445,15 @@ class Compressor(object):
exe = SlimGraphExecutor ( context . place )
exe = SlimGraphExecutor ( context . place )
with scope_guard ( context . scope ) :
with scope_guard ( context . scope ) :
context . optimize_graph . save_persistables ( model_path , exe )
context . optimize_graph . save_persistables ( model_path , exe )
if self . save_eval_model :
context . eval_graph . save_model ( eval_model_path , exe )
if self . prune_infer_model :
context . eval_graph . save_infer_model (
eval_model_path ,
exe ,
self . prune_infer_model ,
program_only = self . save_eval_model )
context . to_file ( context_path )
context . to_file ( context_path )
with open ( strategy_path , ' wb ' ) as strategy_file :
with open ( strategy_path , ' wb ' ) as strategy_file :
pickle . dump ( self . strategies , strategy_file )
pickle . dump ( self . strategies , strategy_file )
@ -485,11 +516,19 @@ class Compressor(object):
"""
"""
Runing evaluation .
Runing evaluation .
"""
"""
results , names = context . run_eval_graph ( )
if self . eval_func is not None :
for name , result in zip ( names , results ) :
for key in self . eval_func :
if name not in context . eval_results :
func = self . eval_func [ key ]
context . eval_results [ name ] = [ ]
if key not in context . eval_results :
context . eval_results [ name ] . append ( result )
context . eval_results [ key ] = [ ]
context . eval_results [ key ] . append (
func ( self . eval_graph . program , self . scope ) )
else :
results , names = context . run_eval_graph ( )
for name , result in zip ( names , results ) :
if name not in context . eval_results :
context . eval_results [ name ] = [ ]
context . eval_results [ name ] . append ( result )
def run ( self ) :
def run ( self ) :
"""
"""