!1606 Adapt LSTM network to CPU target.

Merge pull request !1606 from caojian05/ms_master_dev
pull/1606/MERGE
mindspore-ci-bot committed 5 years ago (via Gitee)
commit c0d38e40a4

@@ -17,7 +17,7 @@ import math
 import numpy as np
-from mindspore import Parameter, Tensor, nn
+from mindspore import Parameter, Tensor, nn, context, ParameterTuple
 from mindspore.common.initializer import initializer
 from mindspore.ops import operations as P
@@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
     if bidirectional:
         num_directions = 2
+    if context.get_context("device_target") == "CPU":
+        h_list = []
+        c_list = []
+        for i in range(num_layers):
+            hi = Parameter(initializer(
+                Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
+                [num_directions, batch_size, hidden_size]
+            ), name='h' + str(i))
+            h_list.append(hi)
+            ci = Parameter(initializer(
+                Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
+                [num_directions, batch_size, hidden_size]
+            ), name='c' + str(i))
+            c_list.append(ci)
+        h = ParameterTuple(tuple(h_list))
+        c = ParameterTuple(tuple(c_list))
+        return h, c
     h = Tensor(
         np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
     c = Tensor(
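
For reference, a minimal sketch of how the patched helper behaves, reconstructed from the hunk above. The lines outside the hunk (the num_directions default and the final returns) and the call at the bottom are assumptions for illustration, not part of this commit: on a CPU target the helper returns per-layer ParameterTuples, while other targets keep the original stacked Tensors.

import numpy as np
from mindspore import Parameter, Tensor, context, ParameterTuple
from mindspore.common.initializer import initializer

def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
    # Reconstructed from the diff hunk; pre-existing lines outside the
    # hunk are filled in by assumption.
    num_directions = 2 if bidirectional else 1
    if context.get_context("device_target") == "CPU":
        # CPU target: one zero-initialized Parameter per layer for each of
        # h and c, grouped into ParameterTuples.
        h_list, c_list = [], []
        for i in range(num_layers):
            h_list.append(Parameter(initializer(
                Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
                [num_directions, batch_size, hidden_size]), name='h' + str(i)))
            c_list.append(Parameter(initializer(
                Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
                [num_directions, batch_size, hidden_size]), name='c' + str(i)))
        return ParameterTuple(tuple(h_list)), ParameterTuple(tuple(c_list))
    # Other targets: a single stacked Tensor per state, as before the patch.
    h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    return h, c

# Hypothetical call, illustrating the shape difference on CPU:
context.set_context(device_target="CPU")
h0, c0 = lstm_default_state(batch_size=32, hidden_size=128, num_layers=2, bidirectional=True)
# h0 and c0 are ParameterTuples of length num_layers; each entry has
# shape (num_directions, batch_size, hidden_size) = (2, 32, 128).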
