|
|
@ -494,8 +494,14 @@ def rnn(cell,
|
|
|
|
if isinstance(initial_states, (list, tuple)):
|
|
|
|
if isinstance(initial_states, (list, tuple)):
|
|
|
|
states = map_structure(lambda x: x, initial_states)[0]
|
|
|
|
states = map_structure(lambda x: x, initial_states)[0]
|
|
|
|
for i, state in enumerate(states):
|
|
|
|
for i, state in enumerate(states):
|
|
|
|
|
|
|
|
if isinstance(state, (list, tuple)):
|
|
|
|
|
|
|
|
for j, state_j in enumerate(state):
|
|
|
|
|
|
|
|
check_variable_and_dtype(state_j, 'state_j[' + str(j) + ']',
|
|
|
|
|
|
|
|
['float32', 'float64'], 'rnn')
|
|
|
|
|
|
|
|
else:
|
|
|
|
check_variable_and_dtype(state, 'states[' + str(i) + ']',
|
|
|
|
check_variable_and_dtype(state, 'states[' + str(i) + ']',
|
|
|
|
['float32', 'float64'], 'rnn')
|
|
|
|
['float32', 'float64'], 'rnn')
|
|
|
|
|
|
|
|
|
|
|
|
check_type(sequence_length, 'sequence_length', (Variable, type(None)),
|
|
|
|
check_type(sequence_length, 'sequence_length', (Variable, type(None)),
|
|
|
|
'rnn')
|
|
|
|
'rnn')
|
|
|
|
|
|
|
|
|
|
|
|