@@ -166,7 +166,7 @@ class BertAttentionMask(nn.Cell):
         super(BertAttentionMask, self).__init__()
         self.has_attention_mask = has_attention_mask
-        self.multiply_data = Tensor([-1000.0,], dtype=dtype)
+        self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
         self.multiply = P.Mul()
         if self.has_attention_mask:
@@ -189,6 +189,7 @@ class BertAttentionMask(nn.Cell):
         return attention_scores


 class BertAttentionMaskBackward(nn.Cell):
     def __init__(self,
                  attention_mask_shape,
@@ -196,7 +197,7 @@ class BertAttentionMaskBackward(nn.Cell):
                  dtype=mstype.float32):
         super(BertAttentionMaskBackward, self).__init__()
         self.has_attention_mask = has_attention_mask
-        self.multiply_data = Tensor([-1000.0,], dtype=dtype)
+        self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
         self.multiply = P.Mul()
         self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
         if self.has_attention_mask:
@@ -218,6 +219,7 @@ class BertAttentionMaskBackward(nn.Cell):
         attention_scores = self.add(adder, attention_scores)
         return attention_scores


 class BertAttentionSoftmax(nn.Cell):
     def __init__(self,
                  batch_size,