@@ -64,7 +64,7 @@ class PrePostProcessLayer(Layer):
             elif cmd == "d":  # add dropout
                 if dropout_rate:
                     self.functors.append(lambda x: layers.dropout(
-                        x, dropout_prob=dropout_rate, is_test=False))
+                        x, dropout_prob=dropout_rate))

     def forward(self, x, residual=None):
         for i, cmd in enumerate(self.process_cmd):
@@ -137,8 +137,7 @@ class MultiHeadAttention(Layer):
             product += attn_bias
         weights = layers.softmax(product)
         if self.dropout_rate:
-            weights = layers.dropout(
-                weights, dropout_prob=self.dropout_rate, is_test=False)
+            weights = layers.dropout(weights, dropout_prob=self.dropout_rate)
         out = layers.matmul(weights, v)
         out = layers.transpose(out, perm=[0, 2, 1, 3])
         out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
@@ -156,8 +155,7 @@ class FFN(Layer):
     def forward(self, x):
         hidden = self.fc1(x)
         if self.dropout_rate:
-            hidden = layers.dropout(
-                hidden, dropout_prob=self.dropout_rate, is_test=False)
+            hidden = layers.dropout(hidden, dropout_prob=self.dropout_rate)
         out = self.fc2(hidden)
         return out

@@ -276,8 +274,8 @@ class WrapEncoder(Layer):
         pos_enc.stop_gradient = True
         emb = word_emb + pos_enc
         enc_input = layers.dropout(
-            emb, dropout_prob=self.emb_dropout,
-            is_test=False) if self.emb_dropout else emb
+            emb,
+            dropout_prob=self.emb_dropout, ) if self.emb_dropout else emb

         enc_output = self.encoder(enc_input, src_slf_attn_bias)
         return enc_output
@@ -407,8 +405,8 @@ class WrapDecoder(Layer):
         pos_enc.stop_gradient = True
         emb = word_emb + pos_enc
         dec_input = layers.dropout(
-            emb, dropout_prob=self.emb_dropout,
-            is_test=False) if self.emb_dropout else emb
+            emb,
+            dropout_prob=self.emb_dropout, ) if self.emb_dropout else emb
         dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
                                   trg_src_attn_bias, caches)
         dec_output = layers.reshape(
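
For reference, a minimal sketch of the call form this patch switches to, assuming the PaddlePaddle 1.x fluid dygraph API; the tensor shape and dropout rate below are illustrative only and do not come from the patch itself:

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    with fluid.dygraph.guard():
        # Illustrative input tensor (shape and values are assumptions).
        x = fluid.dygraph.to_variable(np.ones([2, 8], dtype="float32"))
        # Call form used after this change: dropout_prob only, no explicit is_test flag.
        y = layers.dropout(x, dropout_prob=0.1)
        print(y.shape)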