@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
     def __init__(self, length, max_relative_position):
         super(RelaPosMatrixGenerator, self).__init__()
         self._length = length
-        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+        self._max_relative_position = max_relative_position
+        self._min_relative_position = -max_relative_position
         self.range_length = -length + 1

         self.tile = P.Tile()
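
Note: this hunk (like the ones below) replaces scalar attributes that were wrapped in Tensor with plain Python numbers, so they behave as compile-time constants rather than tensor inputs. For readers unfamiliar with the Shaw-style relative-position scheme this class implements, here is a minimal NumPy sketch of the matrix it produces; the standalone function and its name are illustrative, not part of the patch:

    import numpy as np

    def relative_position_matrix(length, max_relative_position):
        # Pairwise offset of every key position relative to every query position.
        range_vec = np.arange(length)
        distance_mat = range_vec[None, :] - range_vec[:, None]
        # Clip to [-max, max], then shift into [0, 2*max] so each entry can
        # index an embedding table with 2 * max_relative_position + 1 rows.
        clipped = np.clip(distance_mat, -max_relative_position, max_relative_position)
        return clipped + max_relative_position

    print(relative_position_matrix(4, 2))
    # [[2 3 4 4]
    #  [1 2 3 4]
    #  [0 1 2 3]
    #  [0 0 1 2]]
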
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                 max_relative_position=max_relative_position)
         self.reshape = P.Reshape()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
         self.gather = P.GatherV2()  # index_select
         self.matmul = P.BatchMatMul()
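
Note: the nn.OneHot cell fixes depth (and its default on/off values of 1.0/0.0) once at construction, whereas the raw P.OneHot primitive took depth, on_value, and off_value on every call, which is why the on_value/off_value attributes can be dropped. A rough plain-Python sketch of that refactor pattern (not the MindSpore API):

    import numpy as np

    class OneHot:
        # Stand-in for the cell wrapper: call-time constants become
        # constructor state, so the call site only passes indices.
        def __init__(self, depth, on_value=1.0, off_value=0.0):
            self.depth = depth
            self.on_value = on_value
            self.off_value = off_value

        def __call__(self, indices):
            out = np.full(indices.shape + (self.depth,), self.off_value, dtype=np.float32)
            np.put_along_axis(out, indices[..., None], self.on_value, axis=-1)
            return out

    one_hot = OneHot(depth=5)
    print(one_hot(np.array([0, 2, 4])))
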
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         if self.use_one_hot_embeddings:
             flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
             one_hot_relative_positions_matrix = self.one_hot(
-                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+                flat_relative_positions_matrix)
             embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
             my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
             embeddings = self.reshape(embeddings, my_shape)
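
Note: the one-hot branch expresses the embedding lookup as a dense matmul against the embeddings table; mathematically it is the same gather the use_one_hot_embeddings=False branch performs with P.GatherV2, just in a form some accelerator backends handle better. A NumPy sketch of the equivalence (names illustrative):

    import numpy as np

    vocab_size, depth = 7, 4
    embeddings_table = np.random.randn(vocab_size, depth).astype(np.float32)
    flat_positions = np.array([0, 3, 6, 3])

    # One-hot path: lookup as a matmul.
    one_hot = np.eye(vocab_size, dtype=np.float32)[flat_positions]
    emb_matmul = one_hot @ embeddings_table

    # Gather path: direct row indexing.
    emb_gather = embeddings_table[flat_positions]

    assert np.allclose(emb_matmul, emb_gather)
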
@@ -372,11 +370,9 @@ class SaturateCast(nn.Cell):
     def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
         super(SaturateCast, self).__init__()
         np_type = mstype.dtype_to_nptype(dst_type)
-        min_type = np.finfo(np_type).min
-        max_type = np.finfo(np_type).max

-        self.tensor_min_type = Tensor([min_type], dtype=src_type)
-        self.tensor_max_type = Tensor([max_type], dtype=src_type)
+        self.tensor_min_type = float(np.finfo(np_type).min)
+        self.tensor_max_type = float(np.finfo(np_type).max)

         self.min_op = P.Minimum()
         self.max_op = P.Maximum()
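
Note: with the bounds held as plain floats, SaturateCast still clamps the input into the destination dtype's finite range before casting, so float32 values outside float16's range saturate instead of overflowing to inf. A NumPy sketch of the same clamp-then-cast semantics (function name illustrative):

    import numpy as np

    def saturate_cast(x, dst_dtype=np.float16):
        # Clamp into the destination type's finite range before casting,
        # so out-of-range values saturate rather than becoming inf.
        info = np.finfo(dst_dtype)
        return np.clip(x, float(info.min), float(info.max)).astype(dst_dtype)

    x = np.array([1e5, -1e5, 3.14], dtype=np.float32)
    print(saturate_cast(x))        # [ 65504. -65504.  3.14]
    print(x.astype(np.float16))    # [ inf -inf  3.14]
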
@@ -442,7 +438,7 @@ class BertAttention(nn.Cell):
         self.has_attention_mask = has_attention_mask
         self.use_relative_positions = use_relative_positions

-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.shape_from_2d = (-1, from_tensor_width)
         self.shape_to_2d = (-1, to_tensor_width)
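
Note: scores_mul is the standard 1/sqrt(size_per_head) factor from scaled dot-product attention; as a Python float it can be folded into the score computation as a scalar constant. A NumPy sketch of where it enters (shapes and names illustrative):

    import math
    import numpy as np

    def attention_scores(q, k, size_per_head):
        # QK^T scaled by 1/sqrt(d): keeps softmax logits from growing
        # with the head dimension.
        scores_mul = 1.0 / math.sqrt(float(size_per_head))
        return np.matmul(q, k.transpose(0, 2, 1)) * scores_mul

    q = np.random.randn(2, 4, 8).astype(np.float32)  # (batch * heads, seq, size_per_head)
    k = np.random.randn(2, 4, 8).astype(np.float32)
    print(attention_scores(q, k, 8).shape)  # (2, 4, 4)
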
@@ -471,7 +467,7 @@ class BertAttention(nn.Cell):
         self.trans_shape = (0, 2, 1, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
-        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+        self.multiply_data = -10000.0
         self.batch_num = batch_size * num_attention_heads
         self.matmul = P.BatchMatMul()

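Note: multiply_data is the large negative constant used for additive attention masking. Assuming the usual BERT-style scheme (the masking step itself is outside this hunk), masked positions receive an offset of (1 - mask) * -10000.0 before the softmax, driving their weights to ~0. A NumPy sketch (names illustrative):

    import numpy as np

    def masked_softmax(scores, attention_mask, multiply_data=-10000.0):
        # Positions with mask == 0 get a large negative additive offset,
        # so their softmax weight collapses to ~0.
        adder = (1.0 - attention_mask) * multiply_data
        logits = scores + adder
        e = np.exp(logits - logits.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    scores = np.zeros((1, 4), dtype=np.float32)
    mask = np.array([[1.0, 1.0, 0.0, 0.0]], dtype=np.float32)
    print(masked_softmax(scores, mask))  # ~[[0.5 0.5 0.  0. ]]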