|
|
|
@ -18,8 +18,10 @@ import numpy as np
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.layers as layers
|
|
|
|
|
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer
|
|
|
|
|
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
|
|
|
|
|
from paddle.fluid.dygraph.jit import dygraph_to_static_func
|
|
|
|
|
from paddle.fluid.layers.utils import map_structure
|
|
|
|
|
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def position_encoding_init(n_position, d_pos_vec):
|
|
|
|
@ -486,3 +488,169 @@ class Transformer(Layer):
|
|
|
|
|
predict = self.decoder(trg_word, trg_pos, trg_slf_attn_bias,
|
|
|
|
|
trg_src_attn_bias, enc_output)
|
|
|
|
|
return predict
|
|
|
|
|
|
|
|
|
|
@dygraph_to_static_func
|
|
|
|
|
def beam_search(self,
|
|
|
|
|
src_word,
|
|
|
|
|
src_pos,
|
|
|
|
|
src_slf_attn_bias,
|
|
|
|
|
trg_word,
|
|
|
|
|
trg_src_attn_bias,
|
|
|
|
|
bos_id=0,
|
|
|
|
|
eos_id=1,
|
|
|
|
|
beam_size=4,
|
|
|
|
|
max_len=256):
|
|
|
|
|
def expand_to_beam_size(tensor, beam_size):
|
|
|
|
|
tensor = layers.reshape(
|
|
|
|
|
tensor, [tensor.shape[0], 1] + list(tensor.shape[1:]))
|
|
|
|
|
tile_dims = [1] * len(tensor.shape)
|
|
|
|
|
tile_dims[1] = beam_size
|
|
|
|
|
return layers.expand(tensor, tile_dims)
|
|
|
|
|
|
|
|
|
|
def merge_batch_beams(tensor):
|
|
|
|
|
var_dim_in_state = 2 # count in beam dim
|
|
|
|
|
tensor = layers.transpose(
|
|
|
|
|
tensor,
|
|
|
|
|
list(range(var_dim_in_state, len(tensor.shape))) +
|
|
|
|
|
list(range(0, var_dim_in_state)))
|
|
|
|
|
|
|
|
|
|
tensor = layers.reshape(tensor,
|
|
|
|
|
[0] * (len(tensor.shape) - var_dim_in_state
|
|
|
|
|
) + [batch_size * beam_size])
|
|
|
|
|
res = layers.transpose(
|
|
|
|
|
tensor,
|
|
|
|
|
list(
|
|
|
|
|
range((len(tensor.shape) + 1 - var_dim_in_state),
|
|
|
|
|
len(tensor.shape))) +
|
|
|
|
|
list(range(0, (len(tensor.shape) + 1 - var_dim_in_state))))
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
def split_batch_beams(tensor):
|
|
|
|
|
var_dim_in_state = 1
|
|
|
|
|
tensor = layers.transpose(
|
|
|
|
|
tensor,
|
|
|
|
|
list(range(var_dim_in_state, len(tensor.shape))) +
|
|
|
|
|
list(range(0, var_dim_in_state)))
|
|
|
|
|
tensor = layers.reshape(tensor,
|
|
|
|
|
[0] * (len(tensor.shape) - var_dim_in_state
|
|
|
|
|
) + [batch_size, beam_size])
|
|
|
|
|
res = layers.transpose(
|
|
|
|
|
tensor,
|
|
|
|
|
list(
|
|
|
|
|
range((len(tensor.shape) - 1 - var_dim_in_state),
|
|
|
|
|
len(tensor.shape))) +
|
|
|
|
|
list(range(0, (len(tensor.shape) - 1 - var_dim_in_state))))
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
def mask_probs(probs, finished, noend_mask_tensor):
|
|
|
|
|
finished = layers.cast(finished, dtype=probs.dtype)
|
|
|
|
|
probs = layers.elementwise_mul(
|
|
|
|
|
layers.expand(
|
|
|
|
|
layers.unsqueeze(finished, [2]),
|
|
|
|
|
[1, 1, self.trg_vocab_size]),
|
|
|
|
|
noend_mask_tensor,
|
|
|
|
|
axis=-1) - layers.elementwise_mul(
|
|
|
|
|
probs, (finished - 1), axis=0)
|
|
|
|
|
return probs
|
|
|
|
|
|
|
|
|
|
def gather(input, indices, batch_pos):
|
|
|
|
|
topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2)
|
|
|
|
|
return layers.gather_nd(input, topk_coordinates)
|
|
|
|
|
|
|
|
|
|
# run encoder
|
|
|
|
|
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
|
|
|
|
|
batch_size = enc_output.shape[0]
|
|
|
|
|
|
|
|
|
|
# constant number
|
|
|
|
|
inf = float(1. * 1e7)
|
|
|
|
|
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
|
|
|
|
|
vocab_size_tensor = layers.fill_constant(
|
|
|
|
|
shape=[1], dtype="int64", value=self.trg_vocab_size)
|
|
|
|
|
end_token_tensor = to_variable(
|
|
|
|
|
np.full(
|
|
|
|
|
[batch_size, beam_size], eos_id, dtype="int64"))
|
|
|
|
|
noend_array = [-inf] * self.trg_vocab_size
|
|
|
|
|
noend_array[eos_id] = 0
|
|
|
|
|
noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32"))
|
|
|
|
|
batch_pos = layers.expand(
|
|
|
|
|
layers.unsqueeze(
|
|
|
|
|
to_variable(np.arange(
|
|
|
|
|
0, batch_size, 1, dtype="int64")), [1]), [1, beam_size])
|
|
|
|
|
predict_ids = []
|
|
|
|
|
parent_ids = []
|
|
|
|
|
### initialize states of beam search ###
|
|
|
|
|
log_probs = to_variable(
|
|
|
|
|
np.array(
|
|
|
|
|
[[0.] + [-inf] * (beam_size - 1)] * batch_size,
|
|
|
|
|
dtype="float32"))
|
|
|
|
|
|
|
|
|
|
finished = fluid.layers.fill_constant(
|
|
|
|
|
shape=[batch_size, beam_size], value=0, dtype="bool")
|
|
|
|
|
|
|
|
|
|
trg_word = layers.fill_constant(
|
|
|
|
|
shape=[batch_size * beam_size, 1], dtype="int64", value=bos_id)
|
|
|
|
|
|
|
|
|
|
trg_src_attn_bias = merge_batch_beams(
|
|
|
|
|
expand_to_beam_size(trg_src_attn_bias, beam_size))
|
|
|
|
|
enc_output = merge_batch_beams(
|
|
|
|
|
expand_to_beam_size(enc_output, beam_size))
|
|
|
|
|
|
|
|
|
|
# init states (caches) for transformer, need to be updated according to selected beam
|
|
|
|
|
caches = [{
|
|
|
|
|
"k": layers.fill_constant(
|
|
|
|
|
shape=[batch_size, beam_size, self.n_head, 0, self.d_key],
|
|
|
|
|
dtype=enc_output.dtype,
|
|
|
|
|
value=0),
|
|
|
|
|
"v": layers.fill_constant(
|
|
|
|
|
shape=[batch_size, beam_size, self.n_head, 0, self.d_value],
|
|
|
|
|
dtype=enc_output.dtype,
|
|
|
|
|
value=0),
|
|
|
|
|
} for i in range(self.n_layer)]
|
|
|
|
|
|
|
|
|
|
for i in range(max_len):
|
|
|
|
|
trg_pos = layers.zeros_like(
|
|
|
|
|
trg_word) + i # TODO: modified for dygraph2static
|
|
|
|
|
caches = map_structure(merge_batch_beams,
|
|
|
|
|
caches) # TODO: modified for dygraph2static
|
|
|
|
|
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
|
|
|
|
|
enc_output, caches)
|
|
|
|
|
caches = map_structure(split_batch_beams, caches)
|
|
|
|
|
step_log_probs = split_batch_beams(
|
|
|
|
|
fluid.layers.log(fluid.layers.softmax(logits)))
|
|
|
|
|
|
|
|
|
|
step_log_probs = mask_probs(step_log_probs, finished,
|
|
|
|
|
noend_mask_tensor)
|
|
|
|
|
log_probs = layers.elementwise_add(
|
|
|
|
|
x=step_log_probs, y=log_probs, axis=0)
|
|
|
|
|
log_probs = layers.reshape(log_probs,
|
|
|
|
|
[-1, beam_size * self.trg_vocab_size])
|
|
|
|
|
scores = log_probs
|
|
|
|
|
topk_scores, topk_indices = fluid.layers.topk(
|
|
|
|
|
input=scores, k=beam_size)
|
|
|
|
|
beam_indices = fluid.layers.elementwise_floordiv(topk_indices,
|
|
|
|
|
vocab_size_tensor)
|
|
|
|
|
token_indices = fluid.layers.elementwise_mod(topk_indices,
|
|
|
|
|
vocab_size_tensor)
|
|
|
|
|
|
|
|
|
|
# update states
|
|
|
|
|
caches = map_structure(lambda x: gather(x, beam_indices, batch_pos),
|
|
|
|
|
caches)
|
|
|
|
|
log_probs = gather(log_probs, topk_indices, batch_pos)
|
|
|
|
|
finished = gather(finished, beam_indices, batch_pos)
|
|
|
|
|
finished = layers.logical_or(
|
|
|
|
|
finished, layers.equal(token_indices, end_token_tensor))
|
|
|
|
|
trg_word = layers.reshape(token_indices, [-1, 1])
|
|
|
|
|
|
|
|
|
|
predict_ids.append(token_indices)
|
|
|
|
|
parent_ids.append(beam_indices)
|
|
|
|
|
|
|
|
|
|
if layers.reduce_all(finished).numpy():
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
predict_ids = layers.stack(predict_ids, axis=0)
|
|
|
|
|
parent_ids = layers.stack(parent_ids, axis=0)
|
|
|
|
|
finished_seq = layers.transpose(
|
|
|
|
|
layers.gather_tree(predict_ids, parent_ids), [1, 2, 0])
|
|
|
|
|
finished_scores = topk_scores
|
|
|
|
|
|
|
|
|
|
return finished_seq, finished_scores
|
|
|
|
|