|
|
|
@ -123,6 +123,8 @@ class AttentionPredict(object):
|
|
|
|
|
|
|
|
|
|
full_ids = fluid.layers.fill_constant_batch_size_like(
|
|
|
|
|
input=init_state, shape=[-1, 1], dtype='int64', value=1)
|
|
|
|
|
full_scores = fluid.layers.fill_constant_batch_size_like(
|
|
|
|
|
input=init_state, shape=[-1, 1], dtype='float32', value=1)
|
|
|
|
|
|
|
|
|
|
cond = layers.less_than(x=counter, y=array_len)
|
|
|
|
|
while_op = layers.While(cond=cond)
|
|
|
|
@ -171,6 +173,9 @@ class AttentionPredict(object):
|
|
|
|
|
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
|
|
|
|
|
fluid.layers.assign(new_ids, full_ids)
|
|
|
|
|
|
|
|
|
|
new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
|
|
|
|
|
fluid.layers.assign(new_scores, full_scores)
|
|
|
|
|
|
|
|
|
|
layers.increment(x=counter, value=1, in_place=True)
|
|
|
|
|
|
|
|
|
|
# update the memories
|
|
|
|
@ -184,7 +189,7 @@ class AttentionPredict(object):
|
|
|
|
|
length_cond = layers.less_than(x=counter, y=array_len)
|
|
|
|
|
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
|
|
|
|
|
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
|
|
|
|
|
return full_ids
|
|
|
|
|
return full_ids, full_scores
|
|
|
|
|
|
|
|
|
|
def __call__(self, inputs, labels=None, mode=None):
|
|
|
|
|
encoder_features = self.encoder(inputs)
|
|
|
|
@ -223,10 +228,10 @@ class AttentionPredict(object):
|
|
|
|
|
decoder_size, char_num)
|
|
|
|
|
_, decoded_out = layers.topk(input=predict, k=1)
|
|
|
|
|
decoded_out = layers.lod_reset(decoded_out, y=label_out)
|
|
|
|
|
predicts = {'predict': predict, 'decoded_out': decoded_out}
|
|
|
|
|
predicts = {'predict':predict, 'decoded_out':decoded_out}
|
|
|
|
|
else:
|
|
|
|
|
ids = self.gru_attention_infer(
|
|
|
|
|
ids, predict = self.gru_attention_infer(
|
|
|
|
|
decoder_boot, self.max_length, char_num, word_vector_dim,
|
|
|
|
|
encoded_vector, encoded_proj, decoder_size)
|
|
|
|
|
predicts = {'decoded_out': ids}
|
|
|
|
|
predicts = {'predict':predict, 'decoded_out':ids}
|
|
|
|
|
return predicts
|
|
|
|
|