|
|
|
@ -167,7 +167,7 @@ class BertAttentionMask(nn.Cell):
|
|
|
|
|
|
|
|
|
|
super(BertAttentionMask, self).__init__()
|
|
|
|
|
self.has_attention_mask = has_attention_mask
|
|
|
|
|
self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
|
|
|
|
|
self.multiply_data = Tensor([-1000.0,], dtype=dtype)
|
|
|
|
|
self.multiply = P.Mul()
|
|
|
|
|
|
|
|
|
|
if self.has_attention_mask:
|
|
|
|
@ -198,7 +198,7 @@ class BertAttentionMaskBackward(nn.Cell):
|
|
|
|
|
dtype=mstype.float32):
|
|
|
|
|
super(BertAttentionMaskBackward, self).__init__()
|
|
|
|
|
self.has_attention_mask = has_attention_mask
|
|
|
|
|
self.multiply_data = Tensor([-1000.0, ], dtype=dtype)
|
|
|
|
|
self.multiply_data = Tensor([-1000.0,], dtype=dtype)
|
|
|
|
|
self.multiply = P.Mul()
|
|
|
|
|
self.attention_mask = Tensor(np.ones(shape=attention_mask_shape).astype(np.float32))
|
|
|
|
|
if self.has_attention_mask:
|
|
|
|
|