|
|
@ -361,6 +361,7 @@ class Block(nn.Cell):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
def __init__(self, config, layer_idx):
|
|
|
|
def __init__(self, config, layer_idx):
|
|
|
|
super(Block, self).__init__()
|
|
|
|
super(Block, self).__init__()
|
|
|
|
|
|
|
|
scale = 1.0
|
|
|
|
self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
|
|
|
|
self.layernorm1 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
|
|
|
|
self.attention = Attention(config, scale, layer_idx)
|
|
|
|
self.attention = Attention(config, scale, layer_idx)
|
|
|
|
self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
|
|
|
|
self.layernorm2 = LayerNorm((config.embedding_size,)).to_float(config.compute_dtype)
|
|
|
|