From ea78e16e7499dd1d4e90b4b4d3d0c51110f1935a Mon Sep 17 00:00:00 2001 From: baihuawei Date: Tue, 2 Jun 2020 20:09:50 +0800 Subject: [PATCH] fix lstm --- .../kernel/cpu/mkldnn/lstm_cpu_kernel.cc | 59 +++-- .../ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h | 1 + .../kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc | 122 ++++++---- .../kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h | 1 + mindspore/nn/layer/lstm.py | 64 +++-- tests/st/ops/cpu/test_lstm_op.py | 229 ++++++++---------- 6 files changed, 242 insertions(+), 234 deletions(-) diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc index dab165e017..85e792763f 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc @@ -24,17 +24,20 @@ namespace kernel { void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); batch_size_ = SizeToInt(src_shape[1]); seq_len_ = SizeToInt(src_shape[0]); num_directions_ = 1; if (bidirectional_) { num_directions_ = 2; } - int gate_size = 4 * hidden_size_; + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; + const int gate_size = 4 * hidden_size_; for (int i = 0; i < num_layers_; ++i) { weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); weight_h_size_ += gate_size * hidden_size_; @@ -52,11 +55,11 @@ bool LstmCPUKernel::Launch(const std::vector &inputs, auto eng = MKLKernelEngine::Get().engine(); dnnl::stream s(eng); auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; }; + auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; }; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; if (bidirectional_) { direction = dnnl::rnn_direction::bidirectional_concat; } - dim src_dims = {seq_len_, batch_size_, input_size_}; dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; @@ -69,35 +72,43 @@ bool LstmCPUKernel::Launch(const std::vector &inputs, dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); - dnnl::memory::desc weights_desc = formatted_md(weights_dims, tag::ldigo); - dnnl::memory::desc weights_h_desc = formatted_md(weights_h_dims, tag::ldigo); - dnnl::memory::desc bias_desc = formatted_md(bias_dims, tag::ldgo); dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - dnnl::lstm_forward::desc desc = - dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, - weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); + dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc( + dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), + generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + + // construct fw memory auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng); auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); - write_to_dnnl_memory(inputs[0]->addr, src_memory); - - auto src_h_memory = dnnl::memory(prim_desc.src_iter_desc(), eng); - auto src_c_memory = dnnl::memory(prim_desc.src_iter_c_desc(), eng); - write_to_dnnl_memory(inputs[1]->addr, src_h_memory); - write_to_dnnl_memory(inputs[2]->addr, src_c_memory); - - auto weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldigo), eng); - auto weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldigo), eng); - auto bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); - write_to_dnnl_memory(inputs[3]->addr, weights_memory); - write_to_dnnl_memory(reinterpret_cast(inputs[3]->addr) + weight_size_, weights_h_memory); - write_to_dnnl_memory(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_, bias_memory); + src_memory.set_data_handle(inputs[0]->addr); + auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); + auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); + src_h_memory.set_data_handle(inputs[1]->addr); + src_c_memory.set_data_handle(inputs[2]->addr); + auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); + auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); + user_weights_memory.set_data_handle(inputs[3]->addr); + user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); + auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng); + auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng); + dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); + dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); + auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng); + if (has_bias_) { + auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); + user_bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); + dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); + } else { + std::vector net_bias(bias_memory.get_desc().get_size(), 0.0f); + write_to_dnnl_memory(net_bias.data(), bias_memory); + } auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); - auto dst_h_memory = dnnl::memory(prim_desc.dst_iter_desc(), eng); - auto dst_c_memory = dnnl::memory(prim_desc.dst_iter_c_desc(), eng); + auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); + auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); dnnl::lstm_forward fw_layer(prim_desc); workspace_memory.set_data_handle(outputs[3]->addr); dst_memory.set_data_handle(outputs[0]->addr); @@ -113,8 +124,8 @@ bool LstmCPUKernel::Launch(const std::vector &inputs, {DNNL_ARG_DST_ITER, dst_h_memory}, {DNNL_ARG_DST_ITER_C, dst_c_memory}, {DNNL_ARG_WORKSPACE, workspace_memory}}); + s.wait(); return true; } - } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h index 6cb9a1ff74..17013ec267 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h @@ -40,6 +40,7 @@ class LstmCPUKernel : public MKLCPUKernel { int seq_len_; int num_directions_; bool bidirectional_; + bool has_bias_; }; MS_REG_CPU_KERNEL(LSTM, diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc index df4744db6f..f7bc1ae293 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -28,17 +28,20 @@ namespace kernel { void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); batch_size_ = SizeToInt(src_shape[1]); seq_len_ = SizeToInt(src_shape[0]); num_directions_ = 1; if (bidirectional_) { num_directions_ = 2; } - int gate_size = 4 * hidden_size_; + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; + const int gate_size = 4 * hidden_size_; for (int i = 0; i < num_layers_; ++i) { weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); weight_h_size_ += gate_size * hidden_size_; @@ -70,79 +73,92 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); - dnnl::memory::desc weights_desc = formatted_md(weights_dims, tag::ldigo); - dnnl::memory::desc weights_h_desc = formatted_md(weights_h_dims, tag::ldigo); - dnnl::memory::desc bias_desc = formatted_md(bias_dims, tag::ldgo); dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - - dnnl::lstm_forward::desc forward_desc = - dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, - weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); + dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc( + dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), + generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng); - - dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc( - dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), - generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, - src_c_desc, weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); + dnnl::lstm_backward::desc backward_desc = + dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, + generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, + dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), + generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); + // construct fw memory auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); - write_to_dnnl_memory(inputs[0]->addr, src_memory); - - auto src_h_memory = dnnl::memory(prim_forward_desc.src_iter_desc(), eng); - auto src_c_memory = dnnl::memory(prim_forward_desc.src_iter_c_desc(), eng); - write_to_dnnl_memory(inputs[1]->addr, src_h_memory); - write_to_dnnl_memory(inputs[2]->addr, src_c_memory); - - auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldigo), eng); - auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldigo), eng); - auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); - write_to_dnnl_memory(inputs[3]->addr, user_weights_memory); - write_to_dnnl_memory(reinterpret_cast(inputs[3]->addr) + weight_size_, user_weights_h_memory); - write_to_dnnl_memory(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_, user_bias_memory); + src_memory.set_data_handle(inputs[0]->addr); + auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); + auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); + src_h_memory.set_data_handle(inputs[1]->addr); + src_c_memory.set_data_handle(inputs[2]->addr); + auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); + auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); + user_weights_memory.set_data_handle(inputs[3]->addr); + user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng); auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng); - auto bias_memory = dnnl::memory(prim_forward_desc.bias_desc(), eng); dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); - dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); + + // construct bias memory + auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng); + if (has_bias_) { + auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); + user_bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); + dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); + } else { + std::vector net_bias(bias_memory.get_desc().get_size(), 0.0f); + write_to_dnnl_memory(net_bias.data(), bias_memory); + } auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); - write_to_dnnl_memory(reinterpret_cast(inputs[4]->addr), dst_memory); - auto dst_h_memory = dnnl::memory(prim_backward_desc.dst_iter_desc(), eng); - write_to_dnnl_memory(reinterpret_cast(inputs[5]->addr), dst_h_memory); - auto dst_c_memory = dnnl::memory(prim_backward_desc.dst_iter_c_desc(), eng); - write_to_dnnl_memory(reinterpret_cast(inputs[6]->addr), dst_c_memory); + dst_memory.set_data_handle(inputs[4]->addr); + auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); + auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); + dst_h_memory.set_data_handle(inputs[5]->addr); + dst_c_memory.set_data_handle(inputs[6]->addr); auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng); - write_to_dnnl_memory(inputs[10]->addr, workspace_memory); + workspace_memory.set_data_handle(inputs[10]->addr); - // construct diff memory + // construct bw memory + std::vector net_w(weights_memory.get_desc().get_size(), 0.0f); + std::vector net_wh(weights_h_memory.get_desc().get_size(), 0.0f); auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); - auto diff_src_h_memory = dnnl::memory(prim_backward_desc.diff_src_iter_desc(), eng); - auto diff_src_c_memory = dnnl::memory(prim_backward_desc.diff_src_iter_c_desc(), eng); - + auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); + auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); + auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); + auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng); auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng); + write_to_dnnl_memory(net_w.data(), diff_weights_memory); + write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory); + auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng); - auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); - write_to_dnnl_memory(reinterpret_cast(inputs[7]->addr), diff_dst_memory); - auto diff_dst_h_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_desc(), eng); - write_to_dnnl_memory(reinterpret_cast(inputs[8]->addr), diff_dst_h_memory); - auto diff_dst_c_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_c_desc(), eng); - write_to_dnnl_memory(reinterpret_cast(inputs[9]->addr), diff_dst_c_memory); + write_to_dnnl_memory(net_w.data(), diff_bias_memory); + auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); + diff_dst_memory.set_data_handle(inputs[7]->addr); + auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); + diff_dst_h_memory.set_data_handle(inputs[8]->addr); + auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); + diff_dst_c_memory.set_data_handle(inputs[9]->addr); diff_src_memory.set_data_handle(outputs[0]->addr); diff_src_h_memory.set_data_handle(outputs[1]->addr); diff_src_c_memory.set_data_handle(outputs[2]->addr); - diff_weights_memory.set_data_handle(outputs[3]->addr); - diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); - diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); + user_diff_weights_memory.set_data_handle(outputs[3]->addr); + user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); + write_to_dnnl_memory(net_w.data(), user_diff_weights_memory); + write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory); + + // construct bw bias memory + user_diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); + write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); dnnl::lstm_backward bwd_layer(prim_backward_desc); bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory}, {DNNL_ARG_SRC_ITER, src_h_memory}, @@ -163,6 +179,16 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, {DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory}, {DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory}, {DNNL_ARG_WORKSPACE, workspace_memory}}); + dnnl::reorder(diff_weights_memory, user_diff_weights_memory) + .execute(s, diff_weights_memory, user_diff_weights_memory); + dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory) + .execute(s, diff_weights_h_memory, user_diff_weights_h_memory); + if (has_bias_) { + dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory); + } else { + write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); + } + s.wait(); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h index 22ec1f62db..a7bc204e7d 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h @@ -41,6 +41,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel { int seq_len_; int num_directions_; bool bidirectional_; + bool has_bias_; }; MS_REG_CPU_KERNEL(LSTMGrad, diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index 6122e82aaa..998ca70cb7 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -15,7 +15,7 @@ """lstm""" from mindspore.ops import operations as P from mindspore.nn.cell import Cell -from mindspore.common.parameter import Parameter +from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer from mindspore._checkparam import Validator as validator from mindspore import context @@ -149,21 +149,28 @@ class LSTM(Cell): weight_size += increment_size * num_directions self.weight = Parameter(initializer(0.0, [weight_size, 1, 1]), name='weight') else: - layer = [] - layer.append(nn.LSTMCell(input_size=self.input_size, - hidden_size=self.hidden_size, - layer_index=0, - has_bias=self.has_bias, - bidirectional=self.bidirectional, - dropout=self.dropout)) - for i in range(num_layers - 1): - layer.append(nn.LSTMCell(input_size=self.hidden_size * num_directions, - hidden_size=self.hidden_size, - layer_index=i + 1, - has_bias=self.has_bias, - bidirectional=self.bidirectional, - dropout=self.dropout)) - self.lstms = layer + input_size_list = [] + input_size_list.append(self.input_size) + for i in range(self.num_layers - 1): + input_size_list.append(self.hidden_size * num_directions) + weights = [] + layers = [] + bias_size = 0 if not self.has_bias else num_directions * self.hidden_size * 4 + for i in range(num_layers): + weight_size = (input_size_list[i] + self.hidden_size) * num_directions * self.hidden_size * 4 + w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.01 + if has_bias: + bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32) + w_np = np.concatenate([w_np, bias_np], axis=0) + weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i))) + + layers.append(nn.LSTMCell(input_size=input_size_list[i], + hidden_size=self.hidden_size, + has_bias=self.has_bias, + bidirectional=self.bidirectional, + dropout=self.dropout)) + self.lstms = layers + self.weight = ParameterTuple(tuple(weights)) self.fill = P.Fill() self.shape = P.Shape() @@ -177,12 +184,12 @@ class LSTM(Cell): output = self.transpose2(output, (1, 0, 2)) return (output, (h, c)) h, c = hx - output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0]) + output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0], self.weight[0]) for i in range(1, self.num_layers): - output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i]) + output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i], self.weight[i]) if self.batch_first: output = self.transpose2(output, (1, 0, 2)) - return output, hn, cn, _, _ + return (output, (hn, cn)) class LSTMCell(Cell): @@ -271,11 +278,9 @@ class LSTMCell(Cell): >>> output, hn, cn, _, _ = net(input, h0, c0) """ - def __init__(self, input_size, hidden_size, - layer_index=0, has_bias=True, batch_first=False, dropout=0, @@ -283,8 +288,6 @@ class LSTMCell(Cell): super(LSTMCell, self).__init__() self.input_size = input_size self.hidden_size = hidden_size - self.num_layers = 1 - self.layer_index = layer_index self.has_bias = has_bias self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) self.dropout = float(dropout) @@ -295,16 +298,7 @@ class LSTMCell(Cell): if self.batch_first: self.transpose1 = P.Transpose() self.transpose2 = P.Transpose() - w_np = np.ones([(self.input_size + self.hidden_size) * self.num_directions * self.hidden_size * 4, 1]).astype( - np.float32) * 0.01 - if has_bias: - b_np = np.ones([self.num_directions * self.hidden_size * 4, 1]).astype( - np.float32) * 0.01 - else: - b_np = np.zeros([self.num_directions * self.hidden_size * 4, 1]).astype( - np.float32) * 0.01 - wb_np = np.concatenate((w_np, b_np), axis=0).reshape([-1, 1, 1]) - self.w = Parameter(initializer(Tensor(wb_np), wb_np.shape), name='w' + str(self.layer_index)) + self.lstm = P.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=1, @@ -312,10 +306,10 @@ class LSTMCell(Cell): bidirectional=self.bidirectional, dropout=self.dropout) - def construct(self, x, h, c): + def construct(self, x, h, c, w): if self.batch_first: x = self.transpose1(x, (1, 0, 2)) - output, hn, cn, _, _ = self.lstm(x, h, c, self.w) + output, hn, cn, _, _ = self.lstm(x, h, c, w) if self.batch_first: output = self.transpose2(output, (1, 0, 2)) return output, hn, cn, _, _ diff --git a/tests/st/ops/cpu/test_lstm_op.py b/tests/st/ops/cpu/test_lstm_op.py index 2115e46a16..c587c6b49a 100644 --- a/tests/st/ops/cpu/test_lstm_op.py +++ b/tests/st/ops/cpu/test_lstm_op.py @@ -35,27 +35,22 @@ class LstmNet(nn.Cell): if bidirectional: num_directions = 2 - self.lstm = P.LSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]], [[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]], [[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]], [[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]], [[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]] ]).astype(np.float32) - self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x') - - self.h = Parameter(initializer( - Tensor( - np.array([0.1, 0.1, 0.1, 0.1]).reshape((num_layers * num_directions, batch_size, hidden_size)).astype( - np.float32)), - [num_layers * num_directions, batch_size, hidden_size]), name='h') + self.x = Tensor(input_np) - self.c = Parameter(initializer( - Tensor( - np.array([0.2, 0.2, 0.2, 0.2]).reshape((num_layers * num_directions, batch_size, hidden_size)).astype( - np.float32)), - [num_layers * num_directions, batch_size, hidden_size]), name='c') + self.h = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype( + np.float32)) + self.c = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype( + np.float32)) + self.h = tuple((self.h,)) + self.c = tuple((self.c,)) wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01], [-6.4257e-02, -2.4807e-01, 1.3550e-02], # i [-3.2140e-01, 5.5578e-01, 6.3589e-01], @@ -63,7 +58,7 @@ class LstmNet(nn.Cell): [-6.9863e-01, 5.9773e-01, -3.9062e-01], [-3.0253e-01, -1.9464e-01, 7.0591e-01], [-4.0835e-01, 3.6751e-01, 4.7989e-01], - [-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32) # .reshape([1,-1]) + [-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1]) whh = np.array([[-0.4820, -0.2350], [-0.1195, 0.0519], [0.2162, -0.1178], @@ -71,16 +66,16 @@ class LstmNet(nn.Cell): [0.4511, -0.3961], [-0.5962, 0.0906], [0.1867, -0.1225], - [0.1831, 0.0850]]).astype(np.float32) # .reshape([1,-1]) - wih = wih.transpose((1, 0)) - whh = whh.transpose((1, 0)) + [0.1831, 0.0850]]).astype(np.float32).reshape([1, -1]) bih = np.zeros((1, 8)).astype(np.float32) - w_np = np.concatenate((wih, whh, bih), axis=0).reshape([-1, 1, 1]) + w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1]) self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='w') + self.lstm.weight = ParameterTuple((self.w,)) @ms_function def construct(self): - return self.lstm(self.x, self.h, self.c, self.w) + return self.lstm(self.x, (self.h, self.c)) + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @@ -98,40 +93,41 @@ def test_lstm(): if bidirectional: num_directions = 2 net = LstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) - y, h, c, _, _ = net() + y, (h, c) = net() print(y) print(c) print(h) - expect_y = np.array([[[-0.16709016, 0.13125697], - [-0.08438572, -0.01969833]], - [[-0.2746155, 0.32764038], - [-0.06504016, -0.07770399]], - [[-0.00140004, 0.17706314], - [0.03244496, -0.10135599]], - [[0.08328028, 0.06437367], - [-0.04133911, -0.11072896]], - [[0.19004421, -0.02852732], - [0.09138509, -0.00344161]]] - ) - error = np.ones([num_layers, batch_size, hidden_size]) * 1.0e-4 - diff = y.asnumpy() - expect_y - assert np.all(diff < error) - assert np.all(-diff < error) - # - expect_h = np.array([[[0.19004421, -0.02852732], - [0.09138509, -0.00344161]]]) - - error = np.ones((num_layers * num_directions, batch_size, hidden_size)) * 1.0e-4 - diff = h.asnumpy() - expect_h - assert np.all(diff < error) - assert np.all(-diff < error) - # - expect_c = np.array([[[0.34533143, -0.06313794], - [0.169008, -0.00555446]]]) - error = np.ones((num_layers * num_directions, batch_size, hidden_size)) * 1.0e-4 - diff = c.asnumpy() - expect_c - assert np.all(diff < error) - assert np.all(-diff < error) + expect_y = [[[-0.17992045, 0.07819052], + [-0.10745212, -0.06291768]], + + [[-0.28830513, 0.30579978], + [-0.07570618, -0.08868407]], + + [[-0.00814095, 0.16889746], + [0.02814853, -0.11208838]], + + [[0.08157863, 0.06088024], + [-0.04227093, -0.11514835]], + + [[0.18908429, -0.02963362], + [0.09106826, -0.00602506]]] + expect_h = [[[0.18908429, -0.02963362], + [0.09106826, -0.00602506]]] + expect_c = [[[0.3434288, -0.06561527], + [0.16838229, -0.00972614]]] + + diff_y = y.asnumpy() - expect_y + error_y = np.ones([seq_len, batch_size, hidden_size]) * 1.0e-4 + assert np.all(diff_y < error_y) + assert np.all(-diff_y < error_y) + diff_h = h.asnumpy() - expect_h + error_h = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4 + assert np.all(diff_h < error_h) + assert np.all(-diff_h < error_h) + diff_c = c.asnumpy() - expect_c + error_c = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4 + assert np.all(diff_c < error_c) + assert np.all(-diff_c < error_c) class MultiLayerBiLstmNet(nn.Cell): @@ -161,22 +157,15 @@ class MultiLayerBiLstmNet(nn.Cell): [1.2223, -1.3248, 0.1207, -0.8256, 0.1816, 0.7057, -0.3105, 0.5713, 0.2804, -1.0685]]]).astype(np.float32) - self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x') + self.x = Tensor(input_np) - self.h0 = Parameter(initializer( - Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size]), name='h0') - self.c0 = Parameter(initializer( - Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size]), name='c0') - self.h1 = Parameter(initializer( - Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size]), name='h1') - self.c1 = Parameter(initializer( - Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size]), name='c1') - self.h = ParameterTuple((self.h0, self.h1)) - self.c = ParameterTuple((self.c0, self.c1)) + self.h0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)) + self.c0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)) + self.h1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)) + self.c1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)) + + self.h = tuple((self.h0, self.h1)) + self.c = tuple((self.c0, self.c1)) @ms_function def construct(self): @@ -202,7 +191,7 @@ def test_multi_layer_bilstm(): net = MultiLayerBiLstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) - y, h, c, _, _ = net() + y, (h, c) = net() print(y) print(h) print(c) @@ -231,66 +220,53 @@ class Net(nn.Cell): num_directions = 1 if bidirectional: num_directions = 2 - input_np = np.array([[[-0.5907, 1.0557, 1.7283, 0.6706, -1.2550, -0.5298, -0.2290, -0.6735, 0.8555, 1.4836], - [-1.7070, -0.5347, -0.9105, -0.2598, 0.0588, 1.5496, 1.0757, 0.3760, -1.2020, -0.2868]], - - [[0.0151, 0.2126, 0.8090, -0.5292, -2.5590, 0.4279, -0.3081, -1.4706, -0.0498, 1.2301], - [0.4165, -0.5391, -0.0996, 0.1928, -0.4909, -0.1255, 0.4444, -1.3687, 1.3096, 0.6553]], - - [[-0.7802, -0.2083, -0.6388, 1.3757, 0.4293, 0.5363, 0.3202, -0.6687, -1.3864, -0.2953], - [1.0799, -0.7204, 0.1130, -0.5857, -0.4855, -1.1068, 1.0126, 0.8716, 1.5460, -0.7392]], - - [[2.2645, -0.6586, -0.2227, 1.4290, -0.5006, -1.6576, -0.1793, 0.5319, 0.1360, 0.2707], - [-0.4071, 0.1575, 1.4199, -0.9156, 0.1855, 0.4947, 1.0460, -0.6365, 0.1191, -0.6374]], - - [[0.2468, 1.0815, -0.4893, 0.0664, 0.6405, -2.2967, 0.7612, 0.8759, 0.5685, -1.0999], - [-0.7272, -1.7750, -0.1164, -0.7159, 0.0061, -0.7839, -1.8329, 0.3434, -0.5634, - 0.5384]]]).astype(np.float32) - + input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]], + [[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]], + [[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]], + [[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]], + [[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]] + ]).astype(np.float32) self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x') - - self.h0 = Parameter(initializer( - Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size]), name='h0') - - self.c0 = Parameter(initializer( - Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size]), name='c0') - - wih_l0 = np.array([[0.2300, 0.6668, 0.4703, 0.0425, 0.0464, 0.6825, 0.2249, -0.4315, -0.2449, 0.2964], - [-0.2811, -0.3444, 0.2557, -0.5137, -0.5518, 0.1652, -0.6720, 0.1066, 0.3586, 0.6299], - [0.5728, -0.1784, 0.5661, 0.4012, 0.3856, -0.1899, 0.3102, 0.3717, -0.5651, 0.1952], - [0.1026, -0.0527, 0.1198, -0.3080, 0.2292, 0.5757, -0.3567, -0.2731, -0.0586, -0.2849], - [0.2194, -0.1622, 0.3219, -0.3008, -0.3713, -0.3034, -0.2385, 0.0412, -0.5205, 0.0280], - [-0.5499, -0.0733, -0.5236, -0.6753, -0.7045, -0.1839, -0.1037, -0.5026, -0.4055, -0.3416], - [0.1573, -0.1301, -0.2882, -0.3464, 0.6643, 0.1980, -0.6804, 0.5359, 0.5996, 0.0124], - [-0.6436, 0.0587, -0.6520, -0.0471, 0.1667, 0.6042, 0.5752, -0.6296, -0.2976, - -0.3757]]).astype(np.float32).reshape([1, -1]) - - whh_l0 = np.array([[0.3358, 0.2790], - [-0.5355, 0.0989], - [-0.1402, 0.5120], - [0.1335, 0.1653], - [0.3533, -0.3531], - [0.4166, -0.4420], - [-0.5454, -0.1720], - [0.0041, -0.0799]]).astype(np.float32).reshape([1, -1]) - - bih_l0 = np.array([0.5518, 0.1083, 0.4829, 0.0607, -0.1770, -0.6944, 0.3059, 0.5354]).astype( - np.float32).reshape([1, -1]) - bhh_l0 = np.array([0.5025, -0.1261, -0.5405, 0.3220, -0.3441, 0.6488, -0.0284, -0.2334]).astype( - np.float32).reshape([1, -1]) - - w0_np = np.concatenate( - (wih_l0, whh_l0, bih_l0 + bhh_l0), - axis=1).reshape([-1, 1, 1]) - self.w0 = Parameter(initializer(Tensor(w0_np), w0_np.shape), name='w0') - self.lstm = P.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, - has_bias=has_bias, bidirectional=bidirectional, dropout=dropout) + self.hlist = [] + self.clist = [] + self.hlist.append(Parameter(initializer( + Tensor( + np.array([0.1, 0.1, 0.1, 0.1]).reshape((num_directions, batch_size, hidden_size)).astype( + np.float32)), + [num_directions, batch_size, hidden_size]), name='h')) + self.clist.append(Parameter(initializer( + Tensor( + np.array([0.2, 0.2, 0.2, 0.2]).reshape((num_directions, batch_size, hidden_size)).astype( + np.float32)), + [num_directions, batch_size, hidden_size]), name='c')) + self.h = ParameterTuple(tuple(self.hlist)) + self.c = ParameterTuple(tuple(self.clist)) + wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01], + [-6.4257e-02, -2.4807e-01, 1.3550e-02], # i + [-3.2140e-01, 5.5578e-01, 6.3589e-01], + [1.6547e-01, -7.9030e-02, -2.0045e-01], + [-6.9863e-01, 5.9773e-01, -3.9062e-01], + [-3.0253e-01, -1.9464e-01, 7.0591e-01], + [-4.0835e-01, 3.6751e-01, 4.7989e-01], + [-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1]) + whh = np.array([[-0.4820, -0.2350], + [-0.1195, 0.0519], + [0.2162, -0.1178], + [0.6237, 0.0711], + [0.4511, -0.3961], + [-0.5962, 0.0906], + [0.1867, -0.1225], + [0.1831, 0.0850]]).astype(np.float32).reshape([1, -1]) + bih = np.zeros((1, 8)).astype(np.float32) + w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1]) + self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='weight0') + self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, + has_bias=has_bias, bidirectional=bidirectional, dropout=dropout) + self.lstm.weight = ParameterTuple(tuple([self.w])) @ms_function def construct(self): - return self.lstm(self.x, self.h0, self.c0, self.w0)[0] + return self.lstm(self.x, (self.h, self.c))[0] @pytest.mark.level0 @@ -299,7 +275,7 @@ class Net(nn.Cell): def test_grad(): seq_len = 5 batch_size = 2 - input_size = 10 + input_size = 3 hidden_size = 2 num_layers = 1 has_bias = True @@ -329,7 +305,6 @@ def test_grad(): print(dcx) print(dw) -# test_multi_layer_bilstm() -# test_lstm() -# tf_lstm_test() -# test_grad() +test_multi_layer_bilstm() +test_lstm() +test_grad()