@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pdb
 import layers
 
 __all__ = [
@@ -163,7 +162,7 @@ def glu(input, dim=-1):
 def scaled_dot_product_attention(queries,
                                  keys,
                                  values,
-                                 num_heads,
+                                 num_heads=1,
                                  dropout_rate=0.):
     """
     The dot-product attention.
@@ -259,9 +258,12 @@ def scaled_dot_product_attention(queries,
             raise ValueError("Input(x) should be a 4-D Tensor.")
 
         trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
-        return layers.reshape(x=layers.reshape(
+        return layers.reshape(
             x=trans_x,
-            shape=[trans_x.shape[0], trans_x[1], trans_x[2] * trans_x[3]]))
+            shape=map(int, [
+                trans_x.shape[0], trans_x.shape[1],
+                trans_x.shape[2] * trans_x.shape[3]
+            ]))
 
     q = __split_heads(queries, num_heads)
     k = __split_heads(keys, num_heads)
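For reference, a minimal NumPy sketch of what the corrected `__combine_heads` reshape computes: it merges the heads dimension back into the hidden dimension. NumPy and the shape values below are used purely for illustration and are not part of the patch; the point is that the fixed line reads `trans_x.shape[1]`, where the removed line indexed `trans_x[1]` directly.

```python
import numpy as np

# Hypothetical shapes for illustration only:
# x: [batch_size, num_heads, max_sequence_length, hidden_size_per_head]
batch_size, num_heads, seq_len, dim_per_head = 2, 4, 8, 16
x = np.random.rand(batch_size, num_heads, seq_len, dim_per_head)

# Same permutation as layers.transpose(x, perm=[0, 2, 1, 3])
trans_x = x.transpose(0, 2, 1, 3)

# Same reshape as the fixed code: use trans_x.shape[i], not trans_x[i]
combined = trans_x.reshape(trans_x.shape[0], trans_x.shape[1],
                           trans_x.shape[2] * trans_x.shape[3])

assert combined.shape == (batch_size, seq_len, num_heads * dim_per_head)
```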
@@ -271,10 +273,11 @@ def scaled_dot_product_attention(queries,
     scaled_q = layers.scale(x=q, scale=key_dim_per_head**-0.5)
     product = layers.matmul(x=k, y=scaled_q, transpose_y=True)
 
-    attn_scores = layers.reshape(
+    weights = layers.reshape(
         x=layers.reshape(
             x=product, shape=[-1, product.shape[-1]], act="softmax"),
         shape=product.shape)
-    ctx_multiheads = layers.matmul(attn_scores, v)
-    context = __combine_heads(ctx_multiheads)
-    return context
+    if dropout_rate:
+        weights = layers.dropout(weights, dropout_prob=dropout_rate, is_test=False)
+    ctx_multiheads = layers.matmul(weights, v)
+    return __combine_heads(ctx_multiheads)
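For context, here is a minimal single-head NumPy sketch of the computation this hunk performs: scale the queries by `d_k ** -0.5`, take a softmax over the dot products to get attention weights, optionally apply dropout to the weights, then take the weighted sum of the values. The shapes, the `softmax` helper, and the inverted-dropout scaling are illustrative assumptions, not code from the patch.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Illustrative shapes (not taken from the patch); single head for brevity.
seq_len, d_k = 8, 16
rng = np.random.default_rng(0)
q = rng.standard_normal((seq_len, d_k))
k = rng.standard_normal((seq_len, d_k))
v = rng.standard_normal((seq_len, d_k))

dropout_rate = 0.1

# Scale queries by d_k ** -0.5, as the patch does via layers.scale.
scaled_q = q * d_k ** -0.5

# Attention weights: softmax over the query/key dot products.
weights = softmax(scaled_q @ k.T, axis=-1)

# Dropout on the attention weights, mirroring the new `if dropout_rate:`
# branch; inverted dropout is assumed here for the training-time behaviour.
if dropout_rate:
    mask = rng.random(weights.shape) >= dropout_rate
    weights = weights * mask / (1.0 - dropout_rate)

context = weights @ v  # [seq_len, d_k]
```

Applying dropout to the attention weights rather than to the output matches where the new `layers.dropout` call sits in the diff: between the softmax and the final matmul with the values.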