Merge pull request !1618 from baihuawei/cpulstm
pull/1618/MERGE
mindspore-ci-bot committed via Gitee, 5 years ago
commit ddfa1edefe

@@ -24,17 +24,20 @@ namespace kernel {
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
int gate_size = 4 * hidden_size_;
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
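Note: a minimal standalone sketch of the size bookkeeping above (names and the struct are assumptions, and the hunk cuts off before any later direction scaling, which is omitted here). The flat weight buffer packs layer weights, then recurrent weights, then biases:

#include <cstddef>

// Sketch only: mirrors the accumulation in InitKernel above.
struct PackedLstmSizes {
  std::size_t weight = 0;    // input-to-gate weights across layers
  std::size_t weight_h = 0;  // hidden-to-gate (recurrent) weights across layers
};

PackedLstmSizes ComputeSizes(int num_layers, int num_directions, int input_size, int hidden_size) {
  const int gate_size = 4 * hidden_size;  // one block each for the i, f, g, o gates
  PackedLstmSizes s;
  for (int i = 0; i < num_layers; ++i) {
    // Layer 0 consumes the raw input; deeper layers consume the previous
    // layer's output, which is num_directions * hidden_size wide.
    s.weight += static_cast<std::size_t>(gate_size) * (i == 0 ? input_size : hidden_size * num_directions);
    s.weight_h += static_cast<std::size_t>(gate_size) * hidden_size;
  }
  return s;
}
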
@@ -52,11 +55,11 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto eng = MKLKernelEngine::Get().engine();
dnnl::stream s(eng);
auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; };
auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; };
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat;
}
dim src_dims = {seq_len_, batch_size_, input_size_};
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
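Note: the format tags used throughout this diff are standard oneDNN RNN layouts. A compile-only reference sketch with assumed sizes, mapping each tag to the dimensions it indexes:

#include "dnnl.hpp"

void format_tag_reference() {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;
  const dnnl::memory::dim T = 5, N = 3, C = 10, L = 1, D = 1, G = 4, H = 12;
  dnnl::memory::desc src{{T, N, C}, dt::f32, tag::tnc};        // time, batch, channels
  dnnl::memory::desc state{{L, D, N, H}, dt::f32, tag::ldnc};  // layers, dirs, batch, hidden
  dnnl::memory::desc w{{L, D, C, G, H}, dt::f32, tag::ldigo};  // layers, dirs, input, gates, output
  dnnl::memory::desc wt{{L, D, C, G, H}, dt::f32, tag::ldgoi}; // same logical dims, transposed physical layout
  dnnl::memory::desc bias{{L, D, G, H}, dt::f32, tag::ldgo};   // layers, dirs, gates, output
  (void)src; (void)state; (void)w; (void)wt; (void)bias;
}
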
@@ -69,35 +72,43 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc weights_desc = formatted_md(weights_dims, tag::ldigo);
dnnl::memory::desc weights_h_desc = formatted_md(weights_h_dims, tag::ldigo);
dnnl::memory::desc bias_desc = formatted_md(bias_dims, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
dnnl::lstm_forward::desc desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
// construct fw memory
auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng);
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
write_to_dnnl_memory(inputs[0]->addr, src_memory);
auto src_h_memory = dnnl::memory(prim_desc.src_iter_desc(), eng);
auto src_c_memory = dnnl::memory(prim_desc.src_iter_c_desc(), eng);
write_to_dnnl_memory(inputs[1]->addr, src_h_memory);
write_to_dnnl_memory(inputs[2]->addr, src_c_memory);
auto weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldigo), eng);
auto weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldigo), eng);
auto bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
write_to_dnnl_memory(inputs[3]->addr, weights_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_, weights_h_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_, bias_memory);
src_memory.set_data_handle(inputs[0]->addr);
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
src_h_memory.set_data_handle(inputs[1]->addr);
src_c_memory.set_data_handle(inputs[2]->addr);
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng);
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);
auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng);
if (has_bias_) {
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
} else {
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
write_to_dnnl_memory(net_bias.data(), bias_memory);
}
auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
auto dst_h_memory = dnnl::memory(prim_desc.dst_iter_desc(), eng);
auto dst_c_memory = dnnl::memory(prim_desc.dst_iter_c_desc(), eng);
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
dnnl::lstm_forward fw_layer(prim_desc);
workspace_memory.set_data_handle(outputs[3]->addr);
dst_memory.set_data_handle(outputs[0]->addr);
@@ -113,8 +124,8 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
{DNNL_ARG_DST_ITER, dst_h_memory},
{DNNL_ARG_DST_ITER_C, dst_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
s.wait();
return true;
}
} // namespace kernel
} // namespace mindspore
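Note on the hunks above: the rewritten Launch drops the copy-based path (write_to_dnnl_memory into fixed ldigo memories) in favor of zero-copy set_data_handle over the flat user buffer, read as ldgoi, combined with the standard oneDNN tag::any idiom: let the primitive pick its preferred weight layout, then convert with a reorder. A minimal sketch of that idiom; user_ptr is a hypothetical host pointer, the other names match the hunk:

// Plain user layout wrapped around the existing buffer, no copy.
auto user_w = dnnl::memory({{weights_dims}, dt::f32, tag::ldgoi}, eng, user_ptr);
// generic_md(weights_dims) used tag::any, so prim_desc reports the layout
// the primitive actually wants:
auto opt_w = dnnl::memory(prim_desc.weights_layer_desc(), eng);
dnnl::reorder(user_w, opt_w).execute(s, user_w, opt_w);
// When has_bias_ is false the bias memory is zero-filled instead, because
// the LSTM primitive always consumes a bias tensor.
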

@@ -40,6 +40,7 @@ class LstmCPUKernel : public MKLCPUKernel {
int seq_len_;
int num_directions_;
bool bidirectional_;
bool has_bias_;
};
MS_REG_CPU_KERNEL(LSTM,

@@ -28,17 +28,20 @@ namespace kernel {
void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1;
if (bidirectional_) {
num_directions_ = 2;
}
int gate_size = 4 * hidden_size_;
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_;
@@ -70,79 +73,92 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_};
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc weights_desc = formatted_md(weights_dims, tag::ldigo);
dnnl::memory::desc weights_h_desc = formatted_md(weights_h_dims, tag::ldigo);
dnnl::memory::desc bias_desc = formatted_md(bias_dims, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
dnnl::lstm_forward::desc forward_desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng);
dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc(
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
src_c_desc, weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
dnnl::lstm_backward::desc backward_desc =
dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc,
generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc,
dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);
// construct fw memory
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
write_to_dnnl_memory(inputs[0]->addr, src_memory);
auto src_h_memory = dnnl::memory(prim_forward_desc.src_iter_desc(), eng);
auto src_c_memory = dnnl::memory(prim_forward_desc.src_iter_c_desc(), eng);
write_to_dnnl_memory(inputs[1]->addr, src_h_memory);
write_to_dnnl_memory(inputs[2]->addr, src_c_memory);
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldigo), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldigo), eng);
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
write_to_dnnl_memory(inputs[3]->addr, user_weights_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_, user_weights_h_memory);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_, user_bias_memory);
src_memory.set_data_handle(inputs[0]->addr);
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
src_h_memory.set_data_handle(inputs[1]->addr);
src_c_memory.set_data_handle(inputs[2]->addr);
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng);
auto bias_memory = dnnl::memory(prim_forward_desc.bias_desc(), eng);
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
// construct bias memory
auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng);
if (has_bias_) {
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
} else {
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
write_to_dnnl_memory(net_bias.data(), bias_memory);
}
auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[4]->addr), dst_memory);
auto dst_h_memory = dnnl::memory(prim_backward_desc.dst_iter_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[5]->addr), dst_h_memory);
auto dst_c_memory = dnnl::memory(prim_backward_desc.dst_iter_c_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[6]->addr), dst_c_memory);
dst_memory.set_data_handle(inputs[4]->addr);
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
dst_h_memory.set_data_handle(inputs[5]->addr);
dst_c_memory.set_data_handle(inputs[6]->addr);
auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng);
write_to_dnnl_memory(inputs[10]->addr, workspace_memory);
workspace_memory.set_data_handle(inputs[10]->addr);
// construct diff memory
// construct bw memory
std::vector<float> net_w(weights_memory.get_desc().get_size(), 0.0f);
std::vector<float> net_wh(weights_h_memory.get_desc().get_size(), 0.0f);
auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
auto diff_src_h_memory = dnnl::memory(prim_backward_desc.diff_src_iter_desc(), eng);
auto diff_src_c_memory = dnnl::memory(prim_backward_desc.diff_src_iter_c_desc(), eng);
auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng);
auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng);
write_to_dnnl_memory(net_w.data(), diff_weights_memory);
write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory);
auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng);
auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[7]->addr), diff_dst_memory);
auto diff_dst_h_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[8]->addr), diff_dst_h_memory);
auto diff_dst_c_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_c_desc(), eng);
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[9]->addr), diff_dst_c_memory);
write_to_dnnl_memory(net_w.data(), diff_bias_memory);
auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
diff_dst_memory.set_data_handle(inputs[7]->addr);
auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
diff_dst_h_memory.set_data_handle(inputs[8]->addr);
auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
diff_dst_c_memory.set_data_handle(inputs[9]->addr);
diff_src_memory.set_data_handle(outputs[0]->addr);
diff_src_h_memory.set_data_handle(outputs[1]->addr);
diff_src_c_memory.set_data_handle(outputs[2]->addr);
diff_weights_memory.set_data_handle(outputs[3]->addr);
diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
write_to_dnnl_memory(net_w.data(), user_diff_weights_memory);
write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory);
// construct bw bias memory
user_diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
dnnl::lstm_backward bwd_layer(prim_backward_desc);
bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory},
{DNNL_ARG_SRC_ITER, src_h_memory},
@@ -163,6 +179,16 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
{DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory},
{DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory},
{DNNL_ARG_WORKSPACE, workspace_memory}});
dnnl::reorder(diff_weights_memory, user_diff_weights_memory)
.execute(s, diff_weights_memory, user_diff_weights_memory);
dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory)
.execute(s, diff_weights_h_memory, user_diff_weights_h_memory);
if (has_bias_) {
dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory);
} else {
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
}
s.wait();
return true;
}
} // namespace kernel
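Note: the backward kernel runs the same idiom in both directions (a sketch; out_ptr is a hypothetical output pointer, write_to_dnnl_memory is the helper already used in this file). The backward primitive_desc is built with the forward one as a hint so the two agree on the workspace layout; diff-weight memories are zero-filled first, as in the hunk above, because the primitive accumulates into them; and the accumulated gradients are reordered back into the caller's flat ldgoi buffer:

auto fwd_pd = dnnl::lstm_forward::primitive_desc(forward_desc, eng);
auto bwd_pd = dnnl::lstm_backward::primitive_desc(backward_desc, eng, /*hint=*/fwd_pd);
// Accumulation target in the primitive's preferred layout; zeroed first.
auto diff_w_opt = dnnl::memory(bwd_pd.diff_weights_layer_desc(), eng);
std::vector<float> zeros(diff_w_opt.get_desc().get_size() / sizeof(float), 0.0f);
write_to_dnnl_memory(zeros.data(), diff_w_opt);
// Caller-visible plain layout over the gradient output buffer.
auto diff_w_user = dnnl::memory({{weights_dims}, dt::f32, tag::ldgoi}, eng, out_ptr);
// ... bwd_layer.execute(...) runs here ...
dnnl::reorder(diff_w_opt, diff_w_user).execute(s, diff_w_opt, diff_w_user);
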

@@ -41,6 +41,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
int seq_len_;
int num_directions_;
bool bidirectional_;
bool has_bias_;
};
MS_REG_CPU_KERNEL(LSTMGrad,

@@ -15,7 +15,7 @@
"""lstm"""
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator as validator
from mindspore import context
@@ -149,21 +149,28 @@ class LSTM(Cell):
weight_size += increment_size * num_directions
self.weight = Parameter(initializer(0.0, [weight_size, 1, 1]), name='weight')
else:
layer = []
layer.append(nn.LSTMCell(input_size=self.input_size,
hidden_size=self.hidden_size,
layer_index=0,
has_bias=self.has_bias,
bidirectional=self.bidirectional,
dropout=self.dropout))
for i in range(num_layers - 1):
layer.append(nn.LSTMCell(input_size=self.hidden_size * num_directions,
hidden_size=self.hidden_size,
layer_index=i + 1,
has_bias=self.has_bias,
bidirectional=self.bidirectional,
dropout=self.dropout))
self.lstms = layer
input_size_list = []
input_size_list.append(self.input_size)
for i in range(self.num_layers - 1):
input_size_list.append(self.hidden_size * num_directions)
weights = []
layers = []
bias_size = 0 if not self.has_bias else num_directions * self.hidden_size * 4
for i in range(num_layers):
weight_size = (input_size_list[i] + self.hidden_size) * num_directions * self.hidden_size * 4
w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.01
if has_bias:
bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32)
w_np = np.concatenate([w_np, bias_np], axis=0)
weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
layers.append(nn.LSTMCell(input_size=input_size_list[i],
hidden_size=self.hidden_size,
has_bias=self.has_bias,
bidirectional=self.bidirectional,
dropout=self.dropout))
self.lstms = layers
self.weight = ParameterTuple(tuple(weights))
self.fill = P.Fill()
self.shape = P.Shape()
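Note: a concrete check of the flat per-layer weight built above (a runnable sketch with assumed sizes: layer 0, unidirectional, with bias):

import numpy as np

input_size, hidden_size, num_directions = 10, 12, 1
gate_size = 4 * hidden_size
# (input_size + hidden_size) stacked ih/hh weight rows, one gate block each.
weight_size = (input_size + hidden_size) * num_directions * hidden_size * 4
bias_size = num_directions * hidden_size * 4
w_np = np.ones([weight_size, 1, 1], np.float32) * 0.01
b_np = np.zeros([bias_size, 1, 1], np.float32)
flat = np.concatenate([w_np, b_np], axis=0)
print(flat.shape)  # (1104, 1, 1): 22 * 48 weight entries + 48 bias entries
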
@@ -177,12 +184,12 @@ class LSTM(Cell):
output = self.transpose2(output, (1, 0, 2))
return (output, (h, c))
h, c = hx
output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0])
output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0], self.weight[0])
for i in range(1, self.num_layers):
output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i])
output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i], self.weight[i])
if self.batch_first:
output = self.transpose2(output, (1, 0, 2))
return output, hn, cn, _, _
return (output, (hn, cn))
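Note: after this change the weights are owned by LSTM as a ParameterTuple and passed into each stacked cell, while the public call shape is unchanged. A hypothetical end-to-end usage sketch (shapes assumed, not taken from the PR):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.LSTM(input_size=10, hidden_size=12, num_layers=2,
              has_bias=True, bidirectional=False)
x = Tensor(np.ones([5, 3, 10]).astype(np.float32))    # (seq_len, batch, input_size)
h0 = Tensor(np.zeros([2, 3, 12]).astype(np.float32))  # (num_layers * num_directions, batch, hidden)
c0 = Tensor(np.zeros([2, 3, 12]).astype(np.float32))
output, (hn, cn) = net(x, (h0, c0))                   # output: (5, 3, 12)
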
class LSTMCell(Cell):
@@ -271,11 +278,9 @@ class LSTMCell(Cell):
>>> output, hn, cn, _, _ = net(input, h0, c0)
"""
def __init__(self,
input_size,
hidden_size,
layer_index=0,
has_bias=True,
batch_first=False,
dropout=0,
@@ -283,8 +288,6 @@ class LSTMCell(Cell):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = 1
self.layer_index = layer_index
self.has_bias = has_bias
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
self.dropout = float(dropout)
@@ -295,16 +298,7 @@ class LSTMCell(Cell):
if self.batch_first:
self.transpose1 = P.Transpose()
self.transpose2 = P.Transpose()
w_np = np.ones([(self.input_size + self.hidden_size) * self.num_directions * self.hidden_size * 4, 1]).astype(
np.float32) * 0.01
if has_bias:
b_np = np.ones([self.num_directions * self.hidden_size * 4, 1]).astype(
np.float32) * 0.01
else:
b_np = np.zeros([self.num_directions * self.hidden_size * 4, 1]).astype(
np.float32) * 0.01
wb_np = np.concatenate((w_np, b_np), axis=0).reshape([-1, 1, 1])
self.w = Parameter(initializer(Tensor(wb_np), wb_np.shape), name='w' + str(self.layer_index))
self.lstm = P.LSTM(input_size=self.input_size,
hidden_size=self.hidden_size,
num_layers=1,
@@ -312,10 +306,10 @@ class LSTMCell(Cell):
bidirectional=self.bidirectional,
dropout=self.dropout)
def construct(self, x, h, c):
def construct(self, x, h, c, w):
if self.batch_first:
x = self.transpose1(x, (1, 0, 2))
output, hn, cn, _, _ = self.lstm(x, h, c, self.w)
output, hn, cn, _, _ = self.lstm(x, h, c, w)
if self.batch_first:
output = self.transpose2(output, (1, 0, 2))
return output, hn, cn, _, _

File diff suppressed because it is too large