|
|
@ -414,9 +414,9 @@ class SimpleRNN(RNNMixin):
|
|
|
|
time_major=False,
|
|
|
|
time_major=False,
|
|
|
|
dtype="float64"):
|
|
|
|
dtype="float64"):
|
|
|
|
super(SimpleRNN, self).__init__()
|
|
|
|
super(SimpleRNN, self).__init__()
|
|
|
|
|
|
|
|
bidirectional_list = ["bidirectional", "bidirect"]
|
|
|
|
if direction in ["forward", "backward"]:
|
|
|
|
if direction in ["forward"]:
|
|
|
|
is_reverse = direction == "backward"
|
|
|
|
is_reverse = False
|
|
|
|
cell = SimpleRNNCell(
|
|
|
|
cell = SimpleRNNCell(
|
|
|
|
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
|
|
|
|
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
@ -427,7 +427,7 @@ class SimpleRNN(RNNMixin):
|
|
|
|
nonlinearity=nonlinearity,
|
|
|
|
nonlinearity=nonlinearity,
|
|
|
|
dtype=dtype)
|
|
|
|
dtype=dtype)
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
elif direction == "bidirectional":
|
|
|
|
elif direction in bidirectional_list:
|
|
|
|
cell_fw = SimpleRNNCell(
|
|
|
|
cell_fw = SimpleRNNCell(
|
|
|
|
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
|
|
|
|
input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype)
|
|
|
|
cell_bw = SimpleRNNCell(
|
|
|
|
cell_bw = SimpleRNNCell(
|
|
|
@ -447,7 +447,7 @@ class SimpleRNN(RNNMixin):
|
|
|
|
self.input_size = input_size
|
|
|
|
self.input_size = input_size
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
self.dropout = dropout
|
|
|
|
self.dropout = dropout
|
|
|
|
self.num_directions = 2 if direction == "bidirectional" else 1
|
|
|
|
self.num_directions = 2 if direction in bidirectional_list else 1
|
|
|
|
self.time_major = time_major
|
|
|
|
self.time_major = time_major
|
|
|
|
self.num_layers = num_layers
|
|
|
|
self.num_layers = num_layers
|
|
|
|
self.state_components = 1
|
|
|
|
self.state_components = 1
|
|
|
@ -464,14 +464,15 @@ class LSTM(RNNMixin):
|
|
|
|
dtype="float64"):
|
|
|
|
dtype="float64"):
|
|
|
|
super(LSTM, self).__init__()
|
|
|
|
super(LSTM, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
if direction in ["forward", "backward"]:
|
|
|
|
bidirectional_list = ["bidirectional", "bidirect"]
|
|
|
|
is_reverse = direction == "backward"
|
|
|
|
if direction in ["forward"]:
|
|
|
|
|
|
|
|
is_reverse = False
|
|
|
|
cell = LSTMCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell = LSTMCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
for i in range(1, num_layers):
|
|
|
|
for i in range(1, num_layers):
|
|
|
|
cell = LSTMCell(hidden_size, hidden_size, dtype=dtype)
|
|
|
|
cell = LSTMCell(hidden_size, hidden_size, dtype=dtype)
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
elif direction == "bidirectional":
|
|
|
|
elif direction in bidirectional_list:
|
|
|
|
cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
self.append(BiRNN(cell_fw, cell_bw, time_major))
|
|
|
|
self.append(BiRNN(cell_fw, cell_bw, time_major))
|
|
|
@ -487,7 +488,7 @@ class LSTM(RNNMixin):
|
|
|
|
self.input_size = input_size
|
|
|
|
self.input_size = input_size
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
self.dropout = dropout
|
|
|
|
self.dropout = dropout
|
|
|
|
self.num_directions = 2 if direction == "bidirectional" else 1
|
|
|
|
self.num_directions = 2 if direction in bidirectional_list else 1
|
|
|
|
self.time_major = time_major
|
|
|
|
self.time_major = time_major
|
|
|
|
self.num_layers = num_layers
|
|
|
|
self.num_layers = num_layers
|
|
|
|
self.state_components = 2
|
|
|
|
self.state_components = 2
|
|
|
@ -504,14 +505,15 @@ class GRU(RNNMixin):
|
|
|
|
dtype="float64"):
|
|
|
|
dtype="float64"):
|
|
|
|
super(GRU, self).__init__()
|
|
|
|
super(GRU, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
if direction in ["forward", "backward"]:
|
|
|
|
bidirectional_list = ["bidirectional", "bidirect"]
|
|
|
|
is_reverse = direction == "backward"
|
|
|
|
if direction in ["forward"]:
|
|
|
|
|
|
|
|
is_reverse = False
|
|
|
|
cell = GRUCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell = GRUCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
for i in range(1, num_layers):
|
|
|
|
for i in range(1, num_layers):
|
|
|
|
cell = GRUCell(hidden_size, hidden_size, dtype=dtype)
|
|
|
|
cell = GRUCell(hidden_size, hidden_size, dtype=dtype)
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
self.append(RNN(cell, is_reverse, time_major))
|
|
|
|
elif direction == "bidirectional":
|
|
|
|
elif direction in bidirectional_list:
|
|
|
|
cell_fw = GRUCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell_fw = GRUCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell_bw = GRUCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
cell_bw = GRUCell(input_size, hidden_size, dtype=dtype)
|
|
|
|
self.append(BiRNN(cell_fw, cell_bw, time_major))
|
|
|
|
self.append(BiRNN(cell_fw, cell_bw, time_major))
|
|
|
@ -527,7 +529,7 @@ class GRU(RNNMixin):
|
|
|
|
self.input_size = input_size
|
|
|
|
self.input_size = input_size
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
self.dropout = dropout
|
|
|
|
self.dropout = dropout
|
|
|
|
self.num_directions = 2 if direction == "bidirectional" else 1
|
|
|
|
self.num_directions = 2 if direction in bidirectional_list else 1
|
|
|
|
self.time_major = time_major
|
|
|
|
self.time_major = time_major
|
|
|
|
self.num_layers = num_layers
|
|
|
|
self.num_layers = num_layers
|
|
|
|
self.state_components = 1
|
|
|
|
self.state_components = 1
|
|
|
|