|
|
|
@ -23,6 +23,8 @@ namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
using tag = dnnl::memory::format_tag;
|
|
|
|
|
using dim = dnnl::memory::dims;
|
|
|
|
|
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
|
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
|
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
|
|
|
|
@ -36,7 +38,9 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
if (bidirectional_) {
|
|
|
|
|
num_directions_ = 2;
|
|
|
|
|
}
|
|
|
|
|
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
|
|
|
|
|
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "error iteration shape!";
|
|
|
|
|
}
|
|
|
|
|
const int gate_size = 4 * hidden_size_;
|
|
|
|
|
for (int i = 0; i < num_layers_; ++i) {
|
|
|
|
|
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
|
|
|
|
@ -44,18 +48,8 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
}
|
|
|
|
|
weight_size_ = weight_size_ * num_directions_;
|
|
|
|
|
weight_h_size_ = weight_h_size_ * num_directions_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
|
using dt = dnnl::memory::data_type;
|
|
|
|
|
using tag = dnnl::memory::format_tag;
|
|
|
|
|
using dim = dnnl::memory::dims;
|
|
|
|
|
auto eng = MKLKernelEngine::Get().engine();
|
|
|
|
|
dnnl::stream s(eng);
|
|
|
|
|
auto formatted_md = [](dim dimensions, tag layout) { return dnnl::memory::desc{{dimensions}, dt::f32, layout}; };
|
|
|
|
|
auto generic_md = [](dim dimensions) { return dnnl::memory::desc{{dimensions}, dt::f32, tag::any}; };
|
|
|
|
|
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
|
|
|
|
|
if (bidirectional_) {
|
|
|
|
|
direction = dnnl::rnn_direction::bidirectional_concat;
|
|
|
|
@ -63,68 +57,69 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
dim src_dims = {seq_len_, batch_size_, input_size_};
|
|
|
|
|
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
|
|
|
|
dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
|
|
|
|
dim weights_dims = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
|
|
|
|
|
dim weights_h_dims = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
|
|
|
|
|
dim bias_dims = {num_layers_, num_directions_, 4, hidden_size_};
|
|
|
|
|
weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
|
|
|
|
|
weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
|
|
|
|
|
bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_};
|
|
|
|
|
dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_};
|
|
|
|
|
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
|
|
|
|
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
|
|
|
|
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
|
|
|
|
|
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
|
|
|
|
|
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
|
|
|
|
|
dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo);
|
|
|
|
|
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
|
|
|
|
|
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
|
|
|
|
|
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
|
|
|
|
|
dnnl::lstm_forward::desc desc = dnnl::lstm_forward::desc(
|
|
|
|
|
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
|
|
|
|
|
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
|
|
|
|
|
auto prim_desc = dnnl::lstm_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
|
|
|
|
|
dnnl::lstm_forward::desc desc =
|
|
|
|
|
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
|
|
|
|
|
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
|
|
|
|
|
dst_desc, dst_h_desc, dst_c_desc);
|
|
|
|
|
prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng);
|
|
|
|
|
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
|
|
|
|
|
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
|
|
|
|
|
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
|
|
|
|
|
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
|
|
|
|
|
AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc());
|
|
|
|
|
AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc());
|
|
|
|
|
AddArgument(DNNL_ARG_BIAS, bias_desc);
|
|
|
|
|
AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
|
|
|
|
|
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
|
|
|
|
|
AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
|
|
|
|
|
AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// construct fw memory
|
|
|
|
|
auto workspace_memory = dnnl::memory(prim_desc.workspace_desc(), eng);
|
|
|
|
|
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
|
|
|
|
|
src_memory.set_data_handle(inputs[0]->addr);
|
|
|
|
|
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
|
|
|
|
|
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
|
|
|
|
|
src_h_memory.set_data_handle(inputs[1]->addr);
|
|
|
|
|
src_c_memory.set_data_handle(inputs[2]->addr);
|
|
|
|
|
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
|
|
|
|
|
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
|
|
|
|
|
bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
|
using dt = dnnl::memory::data_type;
|
|
|
|
|
using tag = dnnl::memory::format_tag;
|
|
|
|
|
auto eng = MKLKernelEngine::Get().engine();
|
|
|
|
|
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
|
|
|
|
|
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
|
|
|
|
|
auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng);
|
|
|
|
|
auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng);
|
|
|
|
|
user_weights_memory.set_data_handle(inputs[3]->addr);
|
|
|
|
|
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
|
|
|
|
|
auto weights_memory = dnnl::memory(prim_desc.weights_layer_desc(), eng);
|
|
|
|
|
auto weights_h_memory = dnnl::memory(prim_desc.weights_iter_desc(), eng);
|
|
|
|
|
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
|
|
|
|
|
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);
|
|
|
|
|
|
|
|
|
|
auto bias_memory = dnnl::memory(prim_desc.bias_desc(), eng);
|
|
|
|
|
Reorder(&user_weights_memory, &weights_memory);
|
|
|
|
|
Reorder(&user_weights_h_memory, &weights_h_memory);
|
|
|
|
|
auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng);
|
|
|
|
|
if (has_bias_) {
|
|
|
|
|
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
|
|
|
|
|
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
|
|
|
|
|
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
|
|
|
|
|
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
|
|
|
|
|
write_to_dnnl_memory(net_bias.data(), bias_memory);
|
|
|
|
|
std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size());
|
|
|
|
|
}
|
|
|
|
|
auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
|
|
|
|
|
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
|
|
|
|
|
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
|
|
|
|
|
dnnl::lstm_forward fw_layer(prim_desc);
|
|
|
|
|
workspace_memory.set_data_handle(outputs[3]->addr);
|
|
|
|
|
dst_memory.set_data_handle(outputs[0]->addr);
|
|
|
|
|
dst_h_memory.set_data_handle(outputs[1]->addr);
|
|
|
|
|
dst_c_memory.set_data_handle(outputs[2]->addr);
|
|
|
|
|
fw_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory},
|
|
|
|
|
{DNNL_ARG_SRC_ITER, src_h_memory},
|
|
|
|
|
{DNNL_ARG_SRC_ITER_C, src_c_memory},
|
|
|
|
|
{DNNL_ARG_WEIGHTS_LAYER, weights_memory},
|
|
|
|
|
{DNNL_ARG_WEIGHTS_ITER, weights_h_memory},
|
|
|
|
|
{DNNL_ARG_BIAS, bias_memory},
|
|
|
|
|
{DNNL_ARG_DST_LAYER, dst_memory},
|
|
|
|
|
{DNNL_ARG_DST_ITER, dst_h_memory},
|
|
|
|
|
{DNNL_ARG_DST_ITER_C, dst_c_memory},
|
|
|
|
|
{DNNL_ARG_WORKSPACE, workspace_memory}});
|
|
|
|
|
s.wait();
|
|
|
|
|
// set handle
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr);
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr);
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr);
|
|
|
|
|
SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr);
|
|
|
|
|
ExecutePrimitive();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace kernel
|
|
|
|
|