@ -201,8 +201,11 @@ def prepare_distributed_context(place=None):
def _update_input_shapes ( inputs ) :
" Get input shape list by given inputs in Model initialization. "
shapes = None
if isinstance ( inputs , list ) :
if isinstance ( inputs , Input ) :
shapes = [ list ( inputs . shape ) ]
elif isinstance ( inputs , list ) :
shapes = [ list ( input . shape ) for input in inputs ]
elif isinstance ( inputs , dict ) :
shapes = [ list ( inputs [ name ] . shape ) for name in inputs ]
@ -917,9 +920,7 @@ class Model(object):
"""
loss = self . _adapter . train_batch ( inputs , labels )
if fluid . in_dygraph_mode ( ) and self . _input_shapes is None :
self . _input_shapes = self . _adapter . _input_shapes
self . _is_shape_inferred = True
self . _inputs = self . _verify_spec ( None , self . _input_shapes , True )
self . _update_inputs ( )
return loss
def eval_batch ( self , inputs , labels = None ) :
@ -967,9 +968,7 @@ class Model(object):
"""
loss = self . _adapter . eval_batch ( inputs , labels )
if fluid . in_dygraph_mode ( ) and self . _input_shapes is None :
self . _input_shapes = self . _adapter . _input_shapes
self . _is_shape_inferred = True
self . _inputs = self . _verify_spec ( None , self . _input_shapes , True )
self . _update_inputs ( )
return loss
def test_batch ( self , inputs ) :
@ -1012,9 +1011,7 @@ class Model(object):
"""
loss = self . _adapter . test_batch ( inputs )
if fluid . in_dygraph_mode ( ) and self . _input_shapes is None :
self . _input_shapes = self . _adapter . _input_shapes
self . _is_shape_inferred = True
self . _inputs = self . _verify_spec ( None , self . _input_shapes , True )
self . _update_inputs ( )
return loss
def save ( self , path , training = True ) :
@ -1707,7 +1704,7 @@ class Model(object):
layer = self . network
if self . _input_shapes is None : # No provided or inferred
raise RuntimeError (
" Saving inference model needs ' inputs ' or running before saving. Please specify ' inputs ' in Model initialization or input training zqq data and perform a training for shape derivation."
" Saving inference model needs ' inputs ' or running before saving. Please specify ' inputs ' in Model initialization or input training data and perform a training for shape derivation."
)
if self . _is_shape_inferred :
warnings . warn (
@ -1953,3 +1950,9 @@ class Model(object):
except Exception :
steps = None
return steps
def _update_inputs ( self ) :
" Update self._inputs according to given inputs. "
self . _input_shapes = self . _adapter . _input_shapes
self . _is_shape_inferred = True
self . _inputs = self . _verify_spec ( None , self . _input_shapes , True )