pull/14617/head
l00591931 4 years ago
parent 7092868a6c
commit 0c4b0b7914

@ -415,7 +415,7 @@ class GELU(Cell):
def __init__(self): def __init__(self):
super(GELU, self).__init__() super(GELU, self).__init__()
self.gelu = _selected_ops.Gelu() self.gelu = _selected_ops.GeLU()
def construct(self, x): def construct(self, x):
return self.gelu(x) return self.gelu(x)
@ -458,7 +458,7 @@ class FastGelu(Cell):
def __init__(self): def __init__(self):
super(FastGelu, self).__init__() super(FastGelu, self).__init__()
self.fast_gelu = _selected_ops.FastGelu() self.fast_gelu = _selected_ops.FastGeLU()
def construct(self, x): def construct(self, x):
return self.fast_gelu(x) return self.fast_gelu(x)

@ -73,13 +73,13 @@ class Tanh:
@op_selector @op_selector
class Gelu: class GeLU:
def __call__(self, *args): def __call__(self, *args):
pass pass
@op_selector @op_selector
class FastGelu: class FastGeLU:
def __call__(self, *args): def __call__(self, *args):
pass pass

@ -499,7 +499,7 @@ class FeedForward(nn.Cell):
self.layernorm = LayerNorm(in_channels=in_channels) self.layernorm = LayerNorm(in_channels=in_channels)
self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout) self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout)
self.gelu_act = P.Gelu() self.gelu_act = P.GeLU()
self.dropout = nn.Dropout(1 - hidden_dropout) self.dropout = nn.Dropout(1 - hidden_dropout)
self.use_dropout = hidden_dropout > 0 self.use_dropout = hidden_dropout > 0
self.reshape = P.Reshape() self.reshape = P.Reshape()

Loading…
Cancel
Save