@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
     def __init__(self, length, max_relative_position):
         super(RelaPosMatrixGenerator, self).__init__()
         self._length = length
-        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+        self._max_relative_position = max_relative_position
+        self._min_relative_position = -max_relative_position
         self.range_length = -length + 1
         self.tile = P.Tile()
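The hunk above shows the pattern this patch repeats throughout: scalar constants that were wrapped in Tensor at construction time become plain Python numbers, relying on MindSpore operators converting scalar inputs implicitly. A minimal sketch of the idea, with illustrative names that are not part of the patch:

    import mindspore.nn as nn
    import mindspore.ops.operations as P

    class ScaledCell(nn.Cell):
        # Hypothetical cell: a Python float fed straight to an op.
        def __init__(self, scale):
            super(ScaledCell, self).__init__()
            self.scale = scale   # plain float, no Tensor([scale]) wrapper
            self.mul = P.Mul()

        def construct(self, x):
            # P.Mul broadcasts the scalar against the tensor input.
            return self.mul(x, self.scale)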
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                 max_relative_position=max_relative_position)
         self.reshape = P.Reshape()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
         self.gather = P.GatherV2()  # index_select
         self.matmul = P.BatchMatMul()
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         if self.use_one_hot_embeddings:
             flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
             one_hot_relative_positions_matrix = self.one_hot(
-                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+                flat_relative_positions_matrix)
             embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
             my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
             embeddings = self.reshape(embeddings, my_shape)
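Taken together, the two hunks above swap the raw P.OneHot primitive, which needs depth, on_value and off_value passed on every call, for nn.OneHot, which fixes them at construction so the call site receives only the indices. A small usage sketch (the depth value is illustrative; in the patch it is the relative-position vocabulary size):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    # on_value=1.0 and off_value=0.0 are nn.OneHot's defaults, matching the
    # Tensor constants deleted above.
    one_hot = nn.OneHot(depth=7)
    indices = Tensor(np.array([0, 3, 6]), mstype.int32)
    print(one_hot(indices).shape)  # (3, 7)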
@@ -372,11 +370,11 @@ class SaturateCast(nn.Cell):
     def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
         super(SaturateCast, self).__init__()
         np_type = mstype.dtype_to_nptype(dst_type)
-        min_type = np.finfo(np_type).min
-        max_type = np.finfo(np_type).max
+        min_type = float(np.finfo(np_type).min)
+        max_type = float(np.finfo(np_type).max)
 
-        self.tensor_min_type = Tensor([min_type], dtype=src_type)
-        self.tensor_max_type = Tensor([max_type], dtype=src_type)
+        self.tensor_min_type = min_type
+        self.tensor_max_type = max_type
 
         self.min_op = P.Minimum()
         self.max_op = P.Maximum()
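SaturateCast clamps values into the representable range of the destination dtype before casting, so an fp32 value that overflows fp16 saturates instead of becoming inf. After this hunk the bounds are Python floats fed directly to P.Minimum/P.Maximum. A self-contained sketch of the whole cell under that assumption (the construct body is reconstructed for illustration, not quoted from the patch):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    import mindspore.ops.operations as P

    class SaturateCastSketch(nn.Cell):
        def __init__(self, dst_type=mstype.float16):
            super(SaturateCastSketch, self).__init__()
            np_type = mstype.dtype_to_nptype(dst_type)
            self.tensor_min_type = float(np.finfo(np_type).min)
            self.tensor_max_type = float(np.finfo(np_type).max)
            self.min_op = P.Minimum()
            self.max_op = P.Maximum()
            self.cast = P.Cast()
            self.dst_type = dst_type

        def construct(self, x):
            # Clamp into [min, max] of dst_type, then cast safely.
            out = self.max_op(x, self.tensor_min_type)
            out = self.min_op(out, self.tensor_max_type)
            return self.cast(out, self.dst_type)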
@@ -442,7 +440,7 @@ class BertAttention(nn.Cell):
         self.has_attention_mask = has_attention_mask
         self.use_relative_positions = use_relative_positions
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.shape_from_2d = (-1, from_tensor_width)
         self.shape_to_2d = (-1, to_tensor_width)
@@ -471,7 +469,7 @@ class BertAttention(nn.Cell):
         self.trans_shape = (0, 2, 1, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
-        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+        self.multiply_data = -10000.0
         self.batch_num = batch_size * num_attention_heads
         self.matmul = P.BatchMatMul()
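scores_mul and multiply_data are the usual scaled-dot-product ingredients: attention logits are multiplied by 1/sqrt(size_per_head), and masked positions get a -10000.0 additive bias so softmax sends them toward zero. A numpy sketch of the arithmetic these two scalars feed into (toy shapes, not the patch's code):

    import math
    import numpy as np

    size_per_head = 64
    scores_mul = 1.0 / math.sqrt(float(size_per_head))  # now a plain float
    multiply_data = -10000.0

    scores = np.random.randn(2, 2).astype(np.float32)      # toy logits
    mask = np.array([[1.0, 0.0], [1.0, 1.0]], np.float32)  # 1 keep, 0 mask
    biased = scores * scores_mul + (1.0 - mask) * multiply_data
    probs = np.exp(biased) / np.exp(biased).sum(-1, keepdims=True)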