@@ -22,8 +22,10 @@ from mindspore.common.parameter import Parameter
from mindspore import Tensor
from mindspore.common import dtype as mstype


class textrcnn(nn.Cell):
    """TextRCNN network for text classification."""

    def __init__(self, weight, vocab_size, cell, batch_size):
        super(textrcnn, self).__init__()
        self.num_hiddens = 512

@@ -89,7 +91,6 @@ class textrcnn(nn.Cell):
        self.tanh = P.Tanh()
        self.sigmoid = P.Sigmoid()
        self.slice = P.Slice()
        # self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
        #                     num_layers=num_layers, has_bias=has_bias,
        #                     batch_first=batch_first, bidirectional=bidirectional,
        #                     dropout=0.0)

    def construct(self, x):
        """Forward pass; x holds token ids of shape (bs, sl)."""

@@ -100,34 +101,34 @@ class textrcnn(nn.Cell):
        if self.cell == "vanilla":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h1_fw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :]))  # bs, num_hidden
            output_fw = self.expand_dims(h1_fw, 0)  # 1, bs, num_hidden

            for i in range(1, F.shape(x)[0]):
                h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :]))  # bs, num_hidden
                h1_after_expand_fw = self.expand_dims(h1_fw, 0)  # 1, bs, num_hidden
                output_fw = self.concat((output_fw, h1_after_expand_fw))  # i+1, bs, num_hidden
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :]))  # bs, num_hidden
            output_bw = self.expand_dims(h1_bw, 0)  # 1, bs, num_hidden

            for i in range(F.shape(x)[0] - 2, -1, -1):
                h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :]))  # bs, num_hidden
                h1_after_expand_bw = self.expand_dims(h1_bw, 0)  # 1, bs, num_hidden
                output_bw = self.concat((h1_after_expand_bw, output_bw))  # i+1, bs, num_hidden
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden
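            # The two loops above hand-unroll an Elman-style recurrence,
            #     h_t = tanh(W * h_{t-1} + U * x_t),
            # once left-to-right (output_fw) and once right-to-left (output_bw),
            # stacking each step's hidden state along the time axis.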

        if self.cell == "gru":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h_fw = self.cast(self.h1, mstype.float16)  # bs, num_hidden

@@ -148,7 +149,7 @@ class textrcnn(nn.Cell):
                output_fw = self.concat((output_fw, h_after_expand_fw))
            output_fw = self.cast(output_fw, mstype.float16)

            h_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))  # bs, num_hidden + emb_size
            r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))  # bs, num_hidden
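            # r_bw plays the role of the GRU reset gate; the matching update
            # gate and candidate-state computations fall in the lines elided
            # between this hunk and the next.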

@@ -168,29 +169,29 @@ class textrcnn(nn.Cell):
        if self.cell == 'lstm':
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h1_fw_init = self.h1  # bs, num_hidden
            c1_fw_init = self.c1  # bs, num_hidden

            _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden
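            # self.lstm appears to be a raw RNN primitive rather than nn.LSTM:
            # it takes explicit weight/bias tensors and returns eight values,
            # of which only the per-step hidden states are kept here.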

            h1_bw_init = self.h1  # bs, num_hidden
            c1_bw_init = self.c1  # bs, num_hidden
            _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden

        c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1]))  # sl, bs, num_hidden
        c_right = self.concat0((output_bw[1:], self.right_pad_tensor))  # sl, bs, num_hidden
        output = self.concat2((c_left, self.cast(x, mstype.float16), c_right))  # sl, bs, 2*num_hidden + emb_size
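        # TextRCNN word representation (Lai et al., 2015): each position gets
        # [left context; word embedding; right context], with the contexts
        # taken from the forward/backward hidden states shifted one step and
        # padded at the sequence boundaries.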
        output = self.cast(output, mstype.float16)

        output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
        output_dense = self.text_rep_dense(output_flat)  # sl*bs, num_hidden
        output_dense = self.tanh(output_dense)  # sl*bs, num_hidden
        output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens))  # sl, bs, num_hidden
        output = self.reduce_max(output, 0)  # bs, num_hidden
        outputs = self.cast(self.mydense(output), mstype.float16)  # bs, num_classes
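        # The tanh-activated dense layer is the per-position projection and
        # reduce_max implements max-pooling over time, matching the TextRCNN
        # design before the final classification layer.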

        return outputs
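
# A minimal usage sketch (hypothetical values; `emb_table` stands in for a
# pre-trained embedding matrix of shape (vocab_size, embed_size)):
#
#     import numpy as np
#     emb_table = np.random.uniform(-1, 1, (20000, 300)).astype(np.float32)
#     net = textrcnn(weight=Tensor(emb_table), vocab_size=20000,
#                    cell="lstm", batch_size=64)
#     token_ids = Tensor(np.random.randint(0, 20000, (64, 50)), mstype.int32)
#     logits = net(token_ids)  # bs, num_classes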