fix rnn check_type list error (#24346)

* fix rnn check_type list error

* tigger ci, test=develop

* update modify, test=develop
release/2.0-alpha
Xing Wu 5 years ago committed by GitHub
parent 63da846de0
commit 4af3ec0f8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -494,8 +494,14 @@ def rnn(cell,
if isinstance(initial_states, (list, tuple)): if isinstance(initial_states, (list, tuple)):
states = map_structure(lambda x: x, initial_states)[0] states = map_structure(lambda x: x, initial_states)[0]
for i, state in enumerate(states): for i, state in enumerate(states):
if isinstance(state, (list, tuple)):
for j, state_j in enumerate(state):
check_variable_and_dtype(state_j, 'state_j[' + str(j) + ']',
['float32', 'float64'], 'rnn')
else:
check_variable_and_dtype(state, 'states[' + str(i) + ']', check_variable_and_dtype(state, 'states[' + str(i) + ']',
['float32', 'float64'], 'rnn') ['float32', 'float64'], 'rnn')
check_type(sequence_length, 'sequence_length', (Variable, type(None)), check_type(sequence_length, 'sequence_length', (Variable, type(None)),
'rnn') 'rnn')

Loading…
Cancel
Save