# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""

import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore import context


class BertConfig:
    """
    Configuration for `BertModel`.

    Args:
        seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): Size of the vocabulary. Default: 32000.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
            cell. Default: 12.
        num_attention_heads (int): Number of attention heads in the BertTransformer
            encoder cell. Default: 12.
        intermediate_size (int): Size of intermediate layer in the BertTransformer
            encoder cell. Default: 3072.
        hidden_act (str): Activation function used in the BertTransformer encoder
            cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        attention_probs_dropout_prob (float): The dropout probability for
            BertAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
            model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 16.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 seq_length=128,
                 vocab_size=32000,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 dtype=mstype.float32,
                 compute_type=mstype.float32):
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.use_relative_positions = use_relative_positions
        self.dtype = dtype
        self.compute_type = compute_type
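
# Example (editor's sketch, not part of the original file): a hypothetical
# 4-layer "student" configuration, using only keyword arguments defined in
# BertConfig above.
#
#     student_config = BertConfig(seq_length=128,
#                                 hidden_size=384,
#                                 num_hidden_layers=4,
#                                 num_attention_heads=12,
#                                 intermediate_size=1536)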


class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
            each embedding vector.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 embedding_shape,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(initializer_range),
                                          [vocab_size, embedding_size]),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.GatherV2()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)

    def construct(self, input_ids):
        """embedding lookup"""
        extended_ids = self.expand(input_ids, -1)
        flat_ids = self.reshape(extended_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(
                one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table
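
# Shape walkthrough (editor's note, assuming batch_size=2, seq_length=128,
# embedding_size=768): input_ids [2, 128] -> flat_ids [256] -> gathered rows
# [256, 768] -> output reshaped to [2, 128, 768]. The embedding table itself
# is returned as well so downstream heads can reuse its weights.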


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessor that applies positional and token type embeddings to word embeddings.

    Args:
        use_relative_positions (bool): Specifies whether to use relative positions. If True,
            the learned absolute position embeddings are skipped.
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
            each embedding vector.
        use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
        token_type_vocab_size (int): Size of token type vocab. Default: 16.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        max_position_embeddings (int): Maximum length of sequences used in this
            model. Default: 512.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """
    def __init__(self,
                 use_relative_positions,
                 embedding_size,
                 embedding_shape,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 max_position_embeddings=512,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = use_token_type
        self.token_type_vocab_size = token_type_vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.embedding_table = Parameter(initializer
                                         (TruncatedNormal(initializer_range),
                                          [token_type_vocab_size,
                                           embedding_size]),
                                         name='embedding_table')
        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        # off_value must be 0.0 so each one-hot row selects exactly one embedding row.
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.GatherV2()
        self.use_relative_positions = use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer
                                                  (TruncatedNormal(initializer_range),
                                                   [max_position_embeddings,
                                                    embedding_size]),
                                                  name='full_position_embeddings')

    def construct(self, token_type_ids, word_embeddings):
        """embedding postprocessor"""
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids,
                                           self.token_type_vocab_size, self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids,
                                                       self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output
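
# In short (editor's note):
#     output = Dropout(LayerNorm(word_embeddings + token_type_embeddings
#                                [+ position_embeddings]))
# where the absolute position term is added only when use_relative_positions
# is False.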


class BertOutput(nn.Cell):
    """
    Applies a linear transformation to the hidden states, followed by dropout, a
    residual connection to the input, and layer normalization.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 initializer_range=0.02,
                 dropout_prob=0.1,
                 compute_type=mstype.float32):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.add = P.TensorAdd()
        self.is_gpu = context.get_context('device_target') == "GPU"
        if self.is_gpu:
            # On GPU, LayerNorm runs in float32 for numerical stability and the
            # result is cast back to compute_type in construct.
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
            self.compute_type = compute_type
        else:
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)

        self.cast = P.Cast()

    def construct(self, hidden_states, input_tensor):
        """bert output"""
        output = self.dense(hidden_states)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        if self.is_gpu:
            output = self.cast(output, self.compute_type)
        return output
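
# Equivalent formula (editor's note):
#     output = LayerNorm(input_tensor + Dropout(Dense(hidden_states)))
# i.e. the standard post-LayerNorm residual block used throughout BERT.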


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates a matrix of relative positions between inputs.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
    """
    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1
        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        """position matrix generator"""
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)
        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)
        # Shift values to be >= 0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat
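
# Worked example (editor's note, length=3, max_relative_position=2):
# distance_mat[i][j] = j - i gives
#     [[ 0,  1,  2],
#      [-1,  0,  1],
#      [-2, -1,  0]]
# clipping to [-2, 2] leaves it unchanged, and shifting by +2 yields the final
# index matrix
#     [[2, 3, 4],
#      [1, 2, 3],
#      [0, 1, 2]]
# whose entries index rows of the relative-position embedding table below.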


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates a tensor of size [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
    """
    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')
        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = P.Shape()
        self.gather = P.GatherV2()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        """position embedding generation"""
        relative_positions_matrix_out = self.relative_positions_matrix()
        # Generate embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a saturating cast: values are clamped to the representable range of
    the destination type before casting, so the cast cannot overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """
    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = np.finfo(np_type).min
        max_type = np.finfo(np_type).max
        self.tensor_min_type = Tensor([min_type], dtype=src_type)
        self.tensor_max_type = Tensor([max_type], dtype=src_type)
        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        """saturate cast"""
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)
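
# Why this matters (editor's note): casting float32 activations to float16
# without clamping can produce inf, since float16 tops out at about 65504.
# Clamping to [np.finfo(np.float16).min, np.finfo(np.float16).max] first keeps
# mixed-precision computation finite.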


class BertAttention(nn.Cell):
    """
    Applies multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        from_tensor_width (int): Size of last dim of from_tensor.
        to_tensor_width (int): Size of last dim of to_tensor.
        from_seq_length (int): Length of from_tensor sequence.
        to_seq_length (int): Length of to_tensor sequence.
        num_attention_heads (int): Number of attention heads. Default: 1.
        size_per_head (int): Size of each attention head. Default: 512.
        query_act (str): Activation function for the query transform. Default: None.
        key_act (str): Activation function for the key transform. Default: None.
        value_act (str): Activation function for the value transform. Default: None.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
        attention_probs_dropout_prob (float): The dropout probability for
            BertAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        do_return_2d_tensor (bool): If True, return a 2d tensor; otherwise return a 3d
            tensor. Default: False.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
    """
    def __init__(self,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 compute_type=mstype.float32):
        super(BertAttention, self).__init__()
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width,
                                    units,
                                    activation=query_act,
                                    weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width,
                                  units,
                                  activation=key_act,
                                  weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width,
                                    units,
                                    activation=value_act,
                                    weight_init=weight).to_float(compute_type)
        self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
        self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
        self.matmul = P.BatchMatMul()
        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.TensorAdd()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:
            self.shape_return = (-1, num_attention_heads * size_per_head)
        else:
            self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head)
        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=to_seq_length,
                                           depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        """bert attention"""
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)
        value_out = self.value_layer(to_tensor_2d)
        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)
        attention_scores = self.matmul_trans_b(query_layer, key_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_keys is [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length,
                                          -1,
                                          self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r,
                                                      relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length,
                                                  -1,
                                                  self.num_attention_heads,
                                                  self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r,
                                                     self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t
        attention_scores = self.multiply(self.scores_mul, attention_scores)
        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))
            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)
        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)
        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # relations_values is [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(
                attention_probs_t,
                (self.from_seq_length,
                 -1,
                 self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r,
                                                relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length,
                                                    -1,
                                                    self.num_attention_heads,
                                                    self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r,
                                                       self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t
        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)
        return context_layer, attention_scores
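
# Masking math (editor's note): with mask entries in {0, 1},
#     adder = (1 - mask) * -10000.0
# is added to the raw scores, so masked positions receive a large negative
# score and contribute essentially zero probability after the softmax.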


class BertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        seq_length (int): Length of input sequence.
        hidden_size (int): Size of the bert encoder layers.
        num_attention_heads (int): Number of attention heads. Default: 12.
        attention_probs_dropout_prob (float): The dropout probability for
            BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
    """
    def __init__(self,
                 seq_length,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))
        self.size_per_head = int(hidden_size / num_attention_heads)
        self.attention = BertAttention(
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            from_seq_length=seq_length,
            to_seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            do_return_2d_tensor=True,
            compute_type=compute_type)
        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        """bert self attention"""
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output, attention_scores = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output, attention_scores


class BertEncoderCell(nn.Cell):
    """
    Encoder cells used in BertTransformer.

    Args:
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        seq_length (int): Length of input sequence. Default: 512.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of intermediate layer. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
            BertAttention. Default: 0.02.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
    """
    def __init__(self,
                 hidden_size=768,
                 seq_length=512,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.02,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32):
        super(BertEncoderCell, self).__init__()
        self.attention = BertSelfAttention(
            hidden_size=hidden_size,
            seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            use_relative_positions=use_relative_positions,
            compute_type=compute_type)
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=hidden_act,
                                     weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.output = BertOutput(in_channels=intermediate_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type)

    def construct(self, hidden_states, attention_mask):
        """bert encoder cell"""
        # self-attention
        attention_output, attention_scores = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output, attention_scores


class BertTransformer(nn.Cell):
    """
    Multi-layer bert transformer.

    Args:
        hidden_size (int): Size of the encoder layers.
        seq_length (int): Length of input sequence.
        num_hidden_layers (int): Number of hidden layers in encoder cells.
        num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
        intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for
            BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
    """
    def __init__(self,
                 hidden_size,
                 seq_length,
                 num_hidden_layers,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 return_all_encoders=False):
        super(BertTransformer, self).__init__()
        self.return_all_encoders = return_all_encoders
        layers = []
        for _ in range(num_hidden_layers):
            layer = BertEncoderCell(hidden_size=hidden_size,
                                    seq_length=seq_length,
                                    num_attention_heads=num_attention_heads,
                                    intermediate_size=intermediate_size,
                                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                                    use_one_hot_embeddings=use_one_hot_embeddings,
                                    initializer_range=initializer_range,
                                    hidden_dropout_prob=hidden_dropout_prob,
                                    use_relative_positions=use_relative_positions,
                                    hidden_act=hidden_act,
                                    compute_type=compute_type)
            layers.append(layer)
        self.layers = nn.CellList(layers)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)
        self.out_shape = (-1, seq_length, hidden_size)

    def construct(self, input_tensor, attention_mask):
        """bert transformer"""
        prev_output = self.reshape(input_tensor, self.shape)
        all_encoder_layers = ()
        all_encoder_atts = ()
        all_encoder_outputs = ()
        all_encoder_outputs += (prev_output,)
        for layer_module in self.layers:
            layer_output, encoder_att = layer_module(prev_output, attention_mask)
            prev_output = layer_output
            if self.return_all_encoders:
                all_encoder_outputs += (layer_output,)
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers += (layer_output,)
                all_encoder_atts += (encoder_att,)
        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers += (prev_output,)
        return all_encoder_layers, all_encoder_outputs, all_encoder_atts
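
# Output convention (editor's note): with return_all_encoders=True (as the
# models below use it), all_encoder_layers holds each layer's output reshaped
# to [batch, seq, hidden], all_encoder_outputs holds the 2d flattened inputs
# and per-layer outputs, and all_encoder_atts holds each layer's raw attention
# scores. With return_all_encoders=False, only the final reshaped output is
# returned in all_encoder_layers.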


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for BertModel.
    """
    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask = None
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (-1, 1, config.seq_length)

    def construct(self, input_mask):
        attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        return attention_mask
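
# Example (editor's sketch): an input_mask of shape [batch, seq_length] with
# 1s over real tokens and 0s over padding is reshaped to
# [batch, 1, seq_length] and cast to float32; BertAttention later expands it
# to [batch, 1, 1, seq_length] so it broadcasts across heads and query
# positions.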


class BertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
    """
    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False):
        super(BertModel, self).__init__()
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None
        self.last_idx = self.num_hidden_layers - 1
        output_embedding_shape = [-1, self.seq_length,
                                  self.embedding_size]
        self.bert_embedding_lookup = EmbeddingLookup(
            vocab_size=config.vocab_size,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range)
        self.bert_embedding_postprocessor = EmbeddingPostprocessor(
            use_relative_positions=config.use_relative_positions,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)
        self.bert_encoder = BertTransformer(
            hidden_size=self.hidden_size,
            seq_length=self.seq_length,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            use_relative_positions=config.use_relative_positions,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type,
            return_all_encoders=True)
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.squeeze_1 = P.Squeeze(axis=1)
        self.dense = nn.Dense(self.hidden_size, self.hidden_size,
                              activation="tanh",
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        """bert model"""
        # embedding
        word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
        embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings)
        # attention mask [batch_size, 1, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)
        # bert encoder
        encoder_output, encoder_layers, layer_atts = self.bert_encoder(self.cast_compute_type(embedding_output),
                                                                       attention_mask)
        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
        # pooler
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (-1, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)
        encoder_outputs = ()
        for output in encoder_layers:
            encoder_outputs += (self.cast(output, self.dtype),)
        attention_outputs = ()
        for output in layer_atts:
            attention_outputs += (self.cast(output, self.dtype),)
        return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs
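
# Usage sketch (editor's example, not in the original file; shapes are
# illustrative):
#
#     config = BertConfig(seq_length=128)
#     model = BertModel(config, is_training=False)
#     input_ids = Tensor(np.zeros((2, 128)), mstype.int32)
#     token_type_ids = Tensor(np.zeros((2, 128)), mstype.int32)
#     input_mask = Tensor(np.ones((2, 128)), mstype.int32)
#     sequence_output, pooled_output, table, encoder_outputs, att_outputs = \
#         model(input_ids, token_type_ids, input_mask)
#
# sequence_output is [2, 128, hidden_size]; pooled_output is [2, hidden_size],
# the tanh-projected first ([CLS]) token.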


class TinyBertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot encoding form. Default: False.
    """
    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False):
        super(TinyBertModel, self).__init__()
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None
        self.last_idx = self.num_hidden_layers - 1
        output_embedding_shape = [-1, self.seq_length,
                                  self.embedding_size]
        self.tinybert_embedding_lookup = EmbeddingLookup(
            vocab_size=config.vocab_size,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range)
        self.tinybert_embedding_postprocessor = EmbeddingPostprocessor(
            use_relative_positions=config.use_relative_positions,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)
        self.tinybert_encoder = BertTransformer(
            hidden_size=self.hidden_size,
            seq_length=self.seq_length,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            use_relative_positions=config.use_relative_positions,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type,
            return_all_encoders=True)
        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.squeeze_1 = P.Squeeze(axis=1)
        self.dense = nn.Dense(self.hidden_size, self.hidden_size,
                              activation="tanh",
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        """tiny bert model"""
        # embedding
        word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids)
        embedding_output = self.tinybert_embedding_postprocessor(token_type_ids,
                                                                 word_embeddings)
        # attention mask [batch_size, 1, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)
        # bert encoder
        encoder_output, encoder_layers, layer_atts = self.tinybert_encoder(self.cast_compute_type(embedding_output),
                                                                           attention_mask)
        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
        # pooler
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (-1, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)
        encoder_outputs = ()
        for output in encoder_layers:
            encoder_outputs += (self.cast(output, self.dtype),)
        attention_outputs = ()
        for output in layer_atts:
            attention_outputs += (self.cast(output, self.dtype),)
        return sequence_output, pooled_output, embedding_tables, encoder_outputs, attention_outputs


class BertModelCLS(nn.Cell):
    """
    This class is responsible for classification task evaluation,
    i.e. XNLI (num_labels=3), LCQMC (num_labels=2), Chnsenti (num_labels=2).
    The returned output represents the final logits; log_softmax preserves the
    ordering of softmax, so predictions are unchanged.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0,
                 use_one_hot_embeddings=False, phase_type="teacher"):
        super(BertModelCLS, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.phase_type = phase_type
        if self.phase_type == "teacher":
            self.dense = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                  has_bias=True).to_float(config.compute_type)
        else:
            self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                    has_bias=True).to_float(config.compute_type)
        # Dropout over the pooled [CLS] representation; nn.Dropout takes the
        # keep probability, so this is an identity at the default
        # dropout_prob of 0.0.
        self.dropout = nn.Dropout(1 - dropout_prob)

    def construct(self, input_ids, token_type_id, input_mask):
        """classification bert model"""
        _, pooled_output, _, seq_output, att_output = self.bert(input_ids, token_type_id, input_mask)
        cls = self.cast(pooled_output, self.dtype)
        cls = self.dropout(cls)
        if self.phase_type == "teacher":
            logits = self.dense(cls)
        else:
            logits = self.dense_1(cls)
        logits = self.cast(logits, self.dtype)
        log_probs = self.log_softmax(logits)
        return seq_output, att_output, logits, log_probs
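
# Usage sketch (editor's example, hypothetical): a 3-way classifier head for
# XNLI built on the teacher model, reusing the tensors from the BertModel
# sketch above.
#
#     cls_model = BertModelCLS(config, is_training=False, num_labels=3,
#                              phase_type="teacher")
#     seq_out, att_out, logits, log_probs = cls_model(input_ids, token_type_ids,
#                                                     input_mask)
#
# logits has shape [batch, num_labels]; log_probs = log_softmax(logits) is what
# the distillation losses consume.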