|
|
|
@ -858,11 +858,12 @@ class RNNBase(LayerList):
|
|
|
|
|
bias_ih_attr=None,
|
|
|
|
|
bias_hh_attr=None):
|
|
|
|
|
super(RNNBase, self).__init__()
|
|
|
|
|
bidirectional_list = ["bidirectional", "bidirect"]
|
|
|
|
|
self.mode = mode
|
|
|
|
|
self.input_size = input_size
|
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
|
self.dropout = dropout
|
|
|
|
|
self.num_directions = 2 if direction == "bidirectional" else 1
|
|
|
|
|
self.num_directions = 2 if direction in bidirectional_list else 1
|
|
|
|
|
self.time_major = time_major
|
|
|
|
|
self.num_layers = num_layers
|
|
|
|
|
self.state_components = 2 if mode == "LSTM" else 1
|
|
|
|
@ -882,14 +883,14 @@ class RNNBase(LayerList):
|
|
|
|
|
rnn_cls = SimpleRNNCell
|
|
|
|
|
kwargs["activation"] = self.activation
|
|
|
|
|
|
|
|
|
|
if direction in ["forward", "backward"]:
|
|
|
|
|
is_reverse = direction == "backward"
|
|
|
|
|
if direction in ["forward"]:
|
|
|
|
|
is_reverse = False
|
|
|
|
|
cell = rnn_cls(input_size, hidden_size, **kwargs)
|
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
|
for i in range(1, num_layers):
|
|
|
|
|
cell = rnn_cls(hidden_size, hidden_size, **kwargs)
|
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
|
elif direction == "bidirectional":
|
|
|
|
|
elif direction in bidirectional_list:
|
|
|
|
|
cell_fw = rnn_cls(input_size, hidden_size, **kwargs)
|
|
|
|
|
cell_bw = rnn_cls(input_size, hidden_size, **kwargs)
|
|
|
|
|
self.append(BiRNN(cell_fw, cell_bw, time_major))
|
|
|
|
@ -899,13 +900,12 @@ class RNNBase(LayerList):
|
|
|
|
|
self.append(BiRNN(cell_fw, cell_bw, time_major))
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"direction should be forward, backward or bidirectional, "
|
|
|
|
|
"direction should be forward or bidirect (or bidirectional), "
|
|
|
|
|
"received direction = {}".format(direction))
|
|
|
|
|
|
|
|
|
|
self.could_use_cudnn = True
|
|
|
|
|
self.could_use_cudnn &= direction != "backward"
|
|
|
|
|
self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * (
|
|
|
|
|
2 if direction == "bidirectional" else 1)
|
|
|
|
|
2 if direction in bidirectional_list else 1)
|
|
|
|
|
|
|
|
|
|
# Expose params as RNN's attribute, which can make it compatible when
|
|
|
|
|
# replacing small ops composed rnn with cpp rnn kernel.
|
|
|
|
@ -1079,8 +1079,8 @@ class SimpleRNN(RNNBase):
|
|
|
|
|
input_size (int): The input size for the first layer's cell.
|
|
|
|
|
hidden_size (int): The hidden size for each layer's cell.
|
|
|
|
|
num_layers (int, optional): Number of layers. Defaults to 1.
|
|
|
|
|
direction (str, optional): The direction of the network. It can be "forward",
|
|
|
|
|
"backward" and "bidirectional". When "bidirectional", the way to merge
|
|
|
|
|
direction (str, optional): The direction of the network. It can be "forward"
|
|
|
|
|
or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
|
|
|
|
|
outputs of forward and backward is concatenating. Defaults to "forward".
|
|
|
|
|
time_major (bool, optional): Whether the first dimension of the input means the
|
|
|
|
|
time steps. Defaults to False.
|
|
|
|
@ -1195,8 +1195,8 @@ class LSTM(RNNBase):
|
|
|
|
|
input_size (int): The input size for the first layer's cell.
|
|
|
|
|
hidden_size (int): The hidden size for each layer's cell.
|
|
|
|
|
num_layers (int, optional): Number of layers. Defaults to 1.
|
|
|
|
|
direction (str, optional): The direction of the network. It can be "forward",
|
|
|
|
|
"backward" and "bidirectional". When "bidirectional", the way to merge
|
|
|
|
|
direction (str, optional): The direction of the network. It can be "forward"
|
|
|
|
|
or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
|
|
|
|
|
outputs of forward and backward is concatenating. Defaults to "forward".
|
|
|
|
|
time_major (bool, optional): Whether the first dimension of the input
|
|
|
|
|
means the time steps. Defaults to False.
|
|
|
|
@ -1300,8 +1300,8 @@ class GRU(RNNBase):
|
|
|
|
|
input_size (int): The input size for the first layer's cell.
|
|
|
|
|
hidden_size (int): The hidden size for each layer's cell.
|
|
|
|
|
num_layers (int, optional): Number of layers. Defaults to 1.
|
|
|
|
|
direction (str, optional): The direction of the network. It can be "forward",
|
|
|
|
|
"backward" and "bidirectional". When "bidirectional", the way to merge
|
|
|
|
|
direction (str, optional): The direction of the network. It can be "forward"
|
|
|
|
|
or "bidirect"(or "bidirectional"). When "bidirect", the way to merge
|
|
|
|
|
outputs of forward and backward is concatenating. Defaults to "forward".
|
|
|
|
|
time_major (bool, optional): Whether the first dimension of the input
|
|
|
|
|
means the time steps. Defaults to False.
|
|
|
|
|