|
|
|
@ -191,7 +191,7 @@ class StateCell(object):
|
|
|
|
|
self._helper = LayerHelper('state_cell', name=name)
|
|
|
|
|
self._cur_states = {}
|
|
|
|
|
self._state_names = []
|
|
|
|
|
for state_name, state in states.items():
|
|
|
|
|
for state_name, state in six.iteritems(states):
|
|
|
|
|
if not isinstance(state, InitState):
|
|
|
|
|
raise ValueError('state must be an InitState object.')
|
|
|
|
|
self._cur_states[state_name] = state
|
|
|
|
@ -346,7 +346,7 @@ class StateCell(object):
|
|
|
|
|
if self._in_decoder and not self._switched_decoder:
|
|
|
|
|
self._switch_decoder()
|
|
|
|
|
|
|
|
|
|
for input_name, input_value in inputs.items():
|
|
|
|
|
for input_name, input_value in six.iteritems(inputs):
|
|
|
|
|
if input_name not in self._inputs:
|
|
|
|
|
raise ValueError('Unknown input %s. '
|
|
|
|
|
'Please make sure %s in input '
|
|
|
|
@ -361,7 +361,7 @@ class StateCell(object):
|
|
|
|
|
if self._in_decoder and not self._switched_decoder:
|
|
|
|
|
self._switched_decoder()
|
|
|
|
|
|
|
|
|
|
for state_name, decoder_state in self._states_holder.items():
|
|
|
|
|
for state_name, decoder_state in six.iteritems(self._states_holder):
|
|
|
|
|
if id(self._cur_decoder_obj) not in decoder_state:
|
|
|
|
|
raise ValueError('Unknown decoder object, please make sure '
|
|
|
|
|
'switch_decoder been invoked.')
|
|
|
|
@ -671,7 +671,7 @@ class BeamSearchDecoder(object):
|
|
|
|
|
feed_dict = {}
|
|
|
|
|
update_dict = {}
|
|
|
|
|
|
|
|
|
|
for init_var_name, init_var in self._input_var_dict.items():
|
|
|
|
|
for init_var_name, init_var in six.iteritems(self._input_var_dict):
|
|
|
|
|
if init_var_name not in self.state_cell._inputs:
|
|
|
|
|
raise ValueError('Variable ' + init_var_name +
|
|
|
|
|
' not found in StateCell!\n')
|
|
|
|
@ -721,7 +721,8 @@ class BeamSearchDecoder(object):
|
|
|
|
|
self.state_cell.update_states()
|
|
|
|
|
self.update_array(prev_ids, selected_ids)
|
|
|
|
|
self.update_array(prev_scores, selected_scores)
|
|
|
|
|
for update_name, var_to_update in update_dict.items():
|
|
|
|
|
for update_name, var_to_update in six.iteritems(
|
|
|
|
|
update_dict):
|
|
|
|
|
self.update_array(var_to_update, feed_dict[update_name])
|
|
|
|
|
|
|
|
|
|
def read_array(self, init, is_ids=False, is_scores=False):
|
|
|
|
|