@ -395,7 +395,11 @@ class Block(object):
return v
def all_parameters ( self ) :
return { v for k , v in self . vars . iteritems ( ) if isinstance ( v , Parameter ) }
return list ( self . iter_parameters ( ) )
def iter_parameters ( self ) :
return ( item [ 1 ] for item in self . vars . iteritems ( )
if isinstance ( item [ 1 ] , Parameter ) )
def create_var ( self , * args , * * kwargs ) :
var = Variable ( self , * args , * * kwargs )
@ -469,6 +473,37 @@ class Block(object):
for index in range ( len ( self . ops ) ) :
assert self . ops [ index ] . desc == ops_in_cpp [ index ]
def copy_param_info_from ( self , other ) :
"""
Copy the information of parameters from other block
Args :
other ( Block ) : other block
Returns :
None
"""
if not isinstance ( other , Block ) :
raise TypeError ( " copy_param_info_from should be invoked with Block " )
for p in other . iter_parameters ( ) :
assert isinstance ( p , Parameter )
v = self . vars . get ( p . name , None )
if v is None :
raise ValueError ( " copy_param_info_from should be invoked with "
" same topology " )
assert isinstance ( v , Variable )
new_p = Parameter (
block = self ,
shape = v . shape ,
dtype = v . dtype ,
type = v . type ,
lod_level = v . lod_level ,
stop_gradient = p . stop_gradient ,
trainable = p . trainable ,
optimize_attr = p . optimize_attr ,
regularizer = p . regularizer ,
name = v . name )
self . vars [ new_p . name ] = new_p
class Program ( object ) :
def __init__ ( self ) :
@ -489,6 +524,7 @@ class Program(object):
p . desc = core . ProgramDesc ( self . desc )
p . blocks = [ Block ( p , i ) for i in xrange ( self . desc . num_blocks ( ) ) ]
p . sync_with_cpp ( )
p . copy_param_info_from ( self )
return p
def prune ( self , targets ) :
@ -572,6 +608,24 @@ class Program(object):
for block in self . blocks :
block . sync_with_cpp ( )
def copy_param_info_from ( self , other ) :
"""
Copy the information of parameters from other program .
Args :
other ( Program ) : Other program
Returns :
None
"""
if not isinstance ( other , Program ) :
raise TypeError ( " copy_param_info_from should be invoked with "
" Program " )
if len ( self . blocks ) != len ( other . blocks ) :
raise ValueError ( " copy_param_info_from should be invoked with two "
" program, with represent the same topology " )
self . global_block ( ) . copy_param_info_from ( other . global_block ( ) )
def list_vars ( self ) :
for each_block in self . blocks :
for each_var in each_block . vars . itervalues ( ) :