From 9c74e39b12fc83c578f2c6a92e39c2100e4e9c1c Mon Sep 17 00:00:00 2001 From: baihuawei Date: Tue, 9 Jun 2020 17:53:56 +0800 Subject: [PATCH] update cpu lstm --- .../kernel/cpu/mkldnn/lstm_cpu_kernel.cc | 111 +++++----- .../ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h | 8 +- .../kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc | 204 ++++++++---------- .../kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h | 5 +- .../ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc | 8 +- .../ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h | 6 +- .../kernel/cpu/mkldnn/mkl_kernel_engine.cc | 3 + .../kernel/cpu/mkldnn/mkl_kernel_engine.h | 25 +-- mindspore/model_zoo/lstm.py | 58 +---- tests/st/ops/cpu/test_lstm_op.py | 23 +- 10 files changed, 189 insertions(+), 262 deletions(-) diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc index 85e792763f..e0cd2bc552 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc @@ -23,6 +23,8 @@ namespace mindspore { namespace kernel { void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); + using tag = dnnl::memory::format_tag; + using dim = dnnl::memory::dims; std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); @@ -36,7 +38,9 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { if (bidirectional_) { num_directions_ = 2; } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } const int gate_size = 4 * hidden_size_; for (int i = 0; i < num_layers_; ++i) { weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); @@ -44,18 +48,8 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { } weight_size_ = weight_size_ * num_directions_; weight_h_size_ = weight_h_size_ * num_directions_; -} - -bool LstmCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - using dt = dnnl::memory::data_type; - using tag = dnnl::memory::format_tag; - using dim = dnnl::memory::dims; auto eng = MKLKernelEngine::Get().engine(); dnnl::stream s(eng); - auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; }; - auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; }; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; if (bidirectional_) { direction = dnnl::rnn_direction::bidirectional_concat; @@ -63,68 +57,69 @@ bool LstmCPUKernel::Launch(const std::vector &inputs, dim src_dims = {seq_len_, batch_size_, input_size_}; dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; - dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; - dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_}; + weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; + weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; + bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); + dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc( - dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), - generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); - auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + dnnl::lstm_forward::desc desc = + dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, + formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, + dst_desc, dst_h_desc, dst_c_desc); + prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng); + primitive_ = std::make_shared(prim_desc_); + AddArgument(DNNL_ARG_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); + AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); + AddArgument(DNNL_ARG_BIAS, bias_desc); + AddArgument(DNNL_ARG_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); + AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); +} - // construct fw memory - auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng); - auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); - src_memory.set_data_handle(inputs[0]->addr); - auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); - auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); - src_h_memory.set_data_handle(inputs[1]->addr); - src_c_memory.set_data_handle(inputs[2]->addr); - auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); - auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); +bool LstmCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + auto eng = MKLKernelEngine::Get().engine(); + auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); + auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); user_weights_memory.set_data_handle(inputs[3]->addr); user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); - auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng); - auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng); - dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); - dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); - - auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng); + Reorder(&user_weights_memory, &weights_memory); + Reorder(&user_weights_h_memory, &weights_h_memory); + auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); if (has_bias_) { - auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); - user_bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); - dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); + bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); } else { - std::vector net_bias(bias_memory.get_desc().get_size(), 0.0f); - write_to_dnnl_memory(net_bias.data(), bias_memory); + std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size()); } - auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); - auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); - auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); - dnnl::lstm_forward fw_layer(prim_desc); - workspace_memory.set_data_handle(outputs[3]->addr); - dst_memory.set_data_handle(outputs[0]->addr); - dst_h_memory.set_data_handle(outputs[1]->addr); - dst_c_memory.set_data_handle(outputs[2]->addr); - fw_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory}, - {DNNL_ARG_SRC_ITER, src_h_memory}, - {DNNL_ARG_SRC_ITER_C, src_c_memory}, - {DNNL_ARG_WEIGHTS_LAYER, weights_memory}, - {DNNL_ARG_WEIGHTS_ITER, weights_h_memory}, - {DNNL_ARG_BIAS, bias_memory}, - {DNNL_ARG_DST_LAYER, dst_memory}, - {DNNL_ARG_DST_ITER, dst_h_memory}, - {DNNL_ARG_DST_ITER_C, dst_c_memory}, - {DNNL_ARG_WORKSPACE, workspace_memory}}); - s.wait(); + // set handle + SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); + ExecutePrimitive(); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h index 17013ec267..f864009d5f 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ #include #include #include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" @@ -41,6 +41,10 @@ class LstmCPUKernel : public MKLCPUKernel { int num_directions_; bool bidirectional_; bool has_bias_; + dnnl::memory::dims weights_dims_; + dnnl::memory::dims weights_h_dims_; + dnnl::memory::dims bias_dims_; + dnnl::lstm_forward::primitive_desc prim_desc_; }; MS_REG_CPU_KERNEL(LSTM, diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc index f7bc1ae293..cd8b8d5b80 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -24,9 +24,11 @@ namespace mindspore { namespace kernel { - void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); + using tag = dnnl::memory::format_tag; + using dim = dnnl::memory::dims; + auto eng = MKLKernelEngine::Get().engine(); std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); @@ -40,7 +42,9 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { if (bidirectional_) { num_directions_ = 2; } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!"; + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } const int gate_size = 4 * hidden_size_; for (int i = 0; i < num_layers_; ++i) { weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); @@ -48,18 +52,6 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { } weight_size_ = weight_size_ * num_directions_; weight_h_size_ = weight_h_size_ * num_directions_; -} - -bool LSTMGradCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace /*workspace*/, - const std::vector &outputs) { - using tag = dnnl::memory::format_tag; - using dt = dnnl::memory::data_type; - using dim = dnnl::memory::dims; - auto eng = MKLKernelEngine::Get().engine(); - dnnl::stream s(eng); - auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; }; - auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; }; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; if (bidirectional_) { direction = dnnl::rnn_direction::bidirectional_concat; @@ -67,128 +59,112 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, dim src_dims = {seq_len_, batch_size_, input_size_}; dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; - dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; - dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_}; + weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; + weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; + bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); + dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc( - dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), - generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); + dnnl::lstm_forward::desc forward_desc = + dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, + formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, + dst_desc, dst_h_desc, dst_c_desc); auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng); - dnnl::lstm_backward::desc backward_desc = - dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, - generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, - dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims), - generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc); - auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); + dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc( + dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), + formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, + src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, + dst_h_desc, dst_c_desc); + prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc); + primitive_ = std::make_shared(prim_backward_desc_); + AddArgument(DNNL_ARG_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc()); + AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc()); + AddArgument(DNNL_ARG_BIAS, bias_desc); + AddArgument(DNNL_ARG_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); + AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); + AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc()); + AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc()); + AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); + AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); +} + +bool LSTMGradCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace /*workspace*/, + const std::vector &outputs) { + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + auto eng = MKLKernelEngine::Get().engine(); // construct fw memory - auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); - src_memory.set_data_handle(inputs[0]->addr); - auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); - auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); - src_h_memory.set_data_handle(inputs[1]->addr); - src_c_memory.set_data_handle(inputs[2]->addr); - auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); - auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); + auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); + auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); + auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); user_weights_memory.set_data_handle(inputs[3]->addr); user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); - auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng); - auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng); - dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory); - dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory); - - // construct bias memory - auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng); + Reorder(&user_weights_memory, &weights_memory); + Reorder(&user_weights_h_memory, &weights_h_memory); if (has_bias_) { - auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); - user_bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); - dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory); + bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); } else { - std::vector net_bias(bias_memory.get_desc().get_size(), 0.0f); - write_to_dnnl_memory(net_bias.data(), bias_memory); + std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size()); } - - auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); - dst_memory.set_data_handle(inputs[4]->addr); - auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); - auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); - dst_h_memory.set_data_handle(inputs[5]->addr); - dst_c_memory.set_data_handle(inputs[6]->addr); - auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng); - workspace_memory.set_data_handle(inputs[10]->addr); - // construct bw memory - std::vector net_w(weights_memory.get_desc().get_size(), 0.0f); - std::vector net_wh(weights_h_memory.get_desc().get_size(), 0.0f); - auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng); - auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng); - auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng); - auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng); - auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng); - auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng); - auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng); - write_to_dnnl_memory(net_w.data(), diff_weights_memory); - write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory); - auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng); - auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng); - write_to_dnnl_memory(net_w.data(), diff_bias_memory); - - auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng); - diff_dst_memory.set_data_handle(inputs[7]->addr); - auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng); - diff_dst_h_memory.set_data_handle(inputs[8]->addr); - auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng); - diff_dst_c_memory.set_data_handle(inputs[9]->addr); - diff_src_memory.set_data_handle(outputs[0]->addr); - diff_src_h_memory.set_data_handle(outputs[1]->addr); - diff_src_c_memory.set_data_handle(outputs[2]->addr); + auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); + auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); + auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); + auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); user_diff_weights_memory.set_data_handle(outputs[3]->addr); user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); - write_to_dnnl_memory(net_w.data(), user_diff_weights_memory); - write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory); - - // construct bw bias memory - user_diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); - write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); - dnnl::lstm_backward bwd_layer(prim_backward_desc); - bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory}, - {DNNL_ARG_SRC_ITER, src_h_memory}, - {DNNL_ARG_SRC_ITER_C, src_c_memory}, - {DNNL_ARG_WEIGHTS_LAYER, weights_memory}, - {DNNL_ARG_WEIGHTS_ITER, weights_h_memory}, - {DNNL_ARG_BIAS, bias_memory}, - {DNNL_ARG_DST_LAYER, dst_memory}, - {DNNL_ARG_DST_ITER, dst_h_memory}, - {DNNL_ARG_DST_ITER_C, dst_c_memory}, - {DNNL_ARG_DIFF_SRC_LAYER, diff_src_memory}, - {DNNL_ARG_DIFF_SRC_ITER, diff_src_h_memory}, - {DNNL_ARG_DIFF_SRC_ITER_C, diff_src_c_memory}, - {DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory}, - {DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory}, - {DNNL_ARG_DIFF_BIAS, diff_bias_memory}, - {DNNL_ARG_DIFF_DST_LAYER, diff_dst_memory}, - {DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory}, - {DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory}, - {DNNL_ARG_WORKSPACE, workspace_memory}}); - dnnl::reorder(diff_weights_memory, user_diff_weights_memory) - .execute(s, diff_weights_memory, user_diff_weights_memory); - dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory) - .execute(s, diff_weights_h_memory, user_diff_weights_h_memory); + std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size()); + std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size()); if (has_bias_) { - dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory); - } else { - write_to_dnnl_memory(net_w.data(), user_diff_bias_memory); + diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); } - s.wait(); + std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size()); + std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size()); + std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size()); + SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); + SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); + ExecutePrimitive(); + Reorder(&diff_weights_memory, &user_diff_weights_memory); + Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h index a7bc204e7d..1f3fb824c0 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h @@ -42,6 +42,10 @@ class LSTMGradCPUKernel : public MKLCPUKernel { int num_directions_; bool bidirectional_; bool has_bias_; + dnnl::memory::dims weights_dims_; + dnnl::memory::dims weights_h_dims_; + dnnl::memory::dims bias_dims_; + dnnl::lstm_backward::primitive_desc prim_backward_desc_; }; MS_REG_CPU_KERNEL(LSTMGrad, @@ -64,5 +68,4 @@ MS_REG_CPU_KERNEL(LSTMGrad, LSTMGradCPUKernel); } // namespace kernel } // namespace mindspore - #endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc index 17fca72698..a38470e3a3 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc @@ -98,11 +98,9 @@ void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { } void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } -void MKLCPUKernel::write_to_dnnl_memory(void *handle, const dnnl::memory &mem) { - MKLKernelEngine::Get().write_to_dnnl_memory(handle, mem); -} -void MKLCPUKernel::read_from_dnnl_memory(void *handle, const dnnl::memory &mem) { - MKLKernelEngine::Get().read_from_dnnl_memory(handle, mem); + +void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { + MKLKernelEngine::Get().Reorder(src_mem, dst_mem); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h index a6b8d68627..10a860afff 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h @@ -39,10 +39,12 @@ class MKLCPUKernel : public CPUKernel { dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; dnnl::memory::desc GetDefaultMemDesc(const std::vector &shape); void ExecutePrimitive(); - void write_to_dnnl_memory(void *handle, const dnnl::memory &mem); - void read_from_dnnl_memory(void *handle, const dnnl::memory &mem); std::unordered_map arguments_; std::shared_ptr primitive_{nullptr}; + inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) { + return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout}; + } + void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc index f5270a4e9a..5ae9791b12 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc @@ -33,5 +33,8 @@ dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, b return dnnl::memory(mem_desc, engine_, nullptr); } } +void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { + dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h index b0eaaf405f..99e7ecdfe0 100644 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h +++ b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h @@ -41,30 +41,7 @@ class MKLKernelEngine { void Execute(const std::shared_ptr &primitive, const std::unordered_map &arguments); - - inline void read_from_dnnl_memory(void *handle, const dnnl::memory &mem) { - dnnl::engine eng = mem.get_engine(); - size_t bytes = mem.get_desc().get_size(); - if (eng.get_kind() == dnnl::engine::kind::cpu) { - auto dst = reinterpret_cast(handle); - uint8_t *src = reinterpret_cast(mem.get_data_handle()); - for (size_t i = 0; i < bytes; ++i) { - dst[i] = src[i]; - } - } - } - // Read from handle, write to memory - inline void write_to_dnnl_memory(void *handle, const dnnl::memory &mem) { - dnnl::engine eng = mem.get_engine(); - size_t bytes = mem.get_desc().get_size(); - if (eng.get_kind() == dnnl::engine::kind::cpu) { - auto src = reinterpret_cast(handle); - uint8_t *dst = reinterpret_cast(mem.get_data_handle()); - for (size_t i = 0; i < bytes; ++i) { - dst[i] = src[i]; - } - } - } + void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); private: MKLKernelEngine() : engine_(dnnl::engine::kind::cpu, 0), stream_(engine_) {} diff --git a/mindspore/model_zoo/lstm.py b/mindspore/model_zoo/lstm.py index 7368bbf8e5..3b06b9399e 100644 --- a/mindspore/model_zoo/lstm.py +++ b/mindspore/model_zoo/lstm.py @@ -13,43 +13,12 @@ # limitations under the License. # ============================================================================ """LSTM.""" -import math import numpy as np -from mindspore import Parameter, Tensor, nn, context, ParameterTuple -from mindspore.common.initializer import initializer +from mindspore import Tensor, nn, context from mindspore.ops import operations as P - -def init_lstm_weight( - input_size, - hidden_size, - num_layers, - bidirectional, - has_bias=True): - """Initialize lstm weight.""" - num_directions = 1 - if bidirectional: - num_directions = 2 - - weight_size = 0 - gate_size = 4 * hidden_size - for layer in range(num_layers): - for _ in range(num_directions): - input_layer_size = input_size if layer == 0 else hidden_size * num_directions - weight_size += gate_size * input_layer_size - weight_size += gate_size * hidden_size - if has_bias: - weight_size += 2 * gate_size - - stdv = 1 / math.sqrt(hidden_size) - w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32) - w = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight') - - return w - - # Initialize short-term memory (h) and long-term memory (c) to 0 def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): """init default input.""" @@ -60,19 +29,15 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): if context.get_context("device_target") == "CPU": h_list = [] c_list = [] - for i in range(num_layers): - hi = Parameter(initializer( - Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size] - ), name='h' + str(i)) + i = 0 + while i < num_layers: + hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) h_list.append(hi) - ci = Parameter(initializer( - Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), - [num_directions, batch_size, hidden_size] - ), name='c' + str(i)) + ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) c_list.append(ci) - h = ParameterTuple(tuple(h_list)) - c = ParameterTuple(tuple(c_list)) + i = i + 1 + h = tuple(h_list) + c = tuple(c_list) return h, c h = Tensor( @@ -108,12 +73,7 @@ class SentimentNet(nn.Cell): has_bias=True, bidirectional=bidirectional, dropout=0.0) - w_init = init_lstm_weight( - embed_size, - num_hiddens, - num_layers, - bidirectional) - self.encoder.weight = w_init + self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) self.concat = P.Concat(1) diff --git a/tests/st/ops/cpu/test_lstm_op.py b/tests/st/ops/cpu/test_lstm_op.py index 773537f2c0..7992bfbf0a 100644 --- a/tests/st/ops/cpu/test_lstm_op.py +++ b/tests/st/ops/cpu/test_lstm_op.py @@ -20,7 +20,6 @@ import mindspore.context as context from mindspore.common.api import ms_function from mindspore.common.initializer import initializer from mindspore.ops import composite as C -from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.common.parameter import ParameterTuple, Parameter @@ -28,7 +27,7 @@ context.set_context(device_target='CPU') class LstmNet(nn.Cell): - def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): + def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): super(LstmNet, self).__init__() num_directions = 1 @@ -92,7 +91,7 @@ def test_lstm(): num_directions = 1 if bidirectional: num_directions = 2 - net = LstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) + net = LstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) y, (h, c) = net() print(y) print(c) @@ -131,7 +130,7 @@ def test_lstm(): class MultiLayerBiLstmNet(nn.Cell): - def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): + def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): super(MultiLayerBiLstmNet, self).__init__() num_directions = 1 @@ -166,6 +165,17 @@ class MultiLayerBiLstmNet(nn.Cell): self.h = tuple((self.h0, self.h1)) self.c = tuple((self.c0, self.c1)) + input_size_list = [input_size, hidden_size * num_directions] + weights = [] + bias_size = 0 if not has_bias else num_directions * hidden_size * 4 + for i in range(num_layers): + weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4 + w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.02 + if has_bias: + bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32) + w_np = np.concatenate([w_np, bias_np], axis=0) + weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i))) + self.lstm.weight = weights @ms_function def construct(self): @@ -176,7 +186,6 @@ class MultiLayerBiLstmNet(nn.Cell): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_multi_layer_bilstm(): - seq_len = 5 batch_size = 2 input_size = 10 hidden_size = 2 @@ -185,7 +194,7 @@ def test_multi_layer_bilstm(): bidirectional = True dropout = 0.0 - net = MultiLayerBiLstmNet(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, + net = MultiLayerBiLstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout) y, (h, c) = net() print(y) @@ -274,7 +283,7 @@ def test_grad(): input_size = 3 hidden_size = 2 num_layers = 1 - has_bias = True + has_bias = False bidirectional = False dropout = 0.0 net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout))