@ -134,3 +134,19 @@ def test_check_str_by_regular():
_check_str_by_regular ( str5 )
with pytest . raises ( ValueError ) :
_check_str_by_regular ( str6 )
def test_parameter_lazy_init ( ) :
# Call init_data() without set default_input.
para = Parameter ( initializer ( ' ones ' , [ 1 , 2 , 3 ] , mstype . float32 ) , ' test1 ' )
assert not isinstance ( para . default_input , Tensor )
para . init_data ( )
assert isinstance ( para . default_input , Tensor )
assert np . array_equal ( para . default_input . asnumpy ( ) , np . ones ( ( 1 , 2 , 3 ) ) )
# Call init_data() after default_input is set.
para = Parameter ( initializer ( ' ones ' , [ 1 , 2 , 3 ] , mstype . float32 ) , ' test2 ' )
assert not isinstance ( para . default_input , Tensor )
para . default_input = Tensor ( np . zeros ( ( 1 , 2 , 3 ) ) )
assert np . array_equal ( para . default_input . asnumpy ( ) , np . zeros ( ( 1 , 2 , 3 ) ) )
para . init_data ( ) # expect no effect.
assert np . array_equal ( para . default_input . asnumpy ( ) , np . zeros ( ( 1 , 2 , 3 ) ) )