|
|
|
@ -28,17 +28,20 @@ namespace kernel {
|
|
|
|
|
void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
|
|
|
|
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
|
|
|
|
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
|
|
|
|
|
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
|
|
|
|
|
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
|
|
|
|
|
num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
|
|
|
|
|
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
|
|
|
|
|
batch_size_ = SizeToInt(src_shape[1]);
|
|
|
|
|
seq_len_ = SizeToInt(src_shape[0]);
|
|
|
|
|
num_directions_ = 1;
|
|
|
|
|
if (bidirectional_) {
|
|
|
|
|
num_directions_ = 2;
|
|
|
|
|
}
|
|
|
|
|
int gate_size = 4 * hidden_size_;
|
|
|
|
|
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) MS_LOG(EXCEPTION) << "error iteration shape!";
|
|
|
|
|
const int gate_size = 4 * hidden_size_;
|
|
|
|
|
for (int i = 0; i < num_layers_; ++i) {
|
|
|
|
|
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
|
|
|
|
|
weight_h_size_ += gate_size * hidden_size_;
|
|
|
|
@ -70,79 +73,92 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_};
|
|
|
|
|
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
|
|
|
|
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
|
|
|
|
|
|
|
|
|
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
|
|
|
|
|
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
|
|
|
|
|
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
|
|
|
|
|
dnnl::memory::desc weights_desc = formatted_md(weights_dims, tag::ldigo);
|
|
|
|
|
dnnl::memory::desc weights_h_desc = formatted_md(weights_h_dims, tag::ldigo);
|
|
|
|
|
dnnl::memory::desc bias_desc = formatted_md(bias_dims, tag::ldgo);
|
|
|
|
|
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
|
|
|
|
|
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
|
|
|
|
|
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
|
|
|
|
|
|
|
|
|
|
dnnl::lstm_forward::desc forward_desc =
|
|
|
|
|
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
|
|
|
|
|
weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
|
|
|
|
|
dnnl::lstm_forward::desc forward_desc = dnnl::lstm_forward::desc(
|
|
|
|
|
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
|
|
|
|
|
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
|
|
|
|
|
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng);
|
|
|
|
|
|
|
|
|
|
dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc(
|
|
|
|
|
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
|
|
|
|
|
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
|
|
|
|
|
src_c_desc, weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
|
|
|
|
|
dnnl::lstm_backward::desc backward_desc =
|
|
|
|
|
dnnl::lstm_backward::desc(dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc,
|
|
|
|
|
generic_md(weights_dims), generic_md(weights_h_dims), generic_md(bias_dims), dst_desc,
|
|
|
|
|
dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, generic_md(weights_dims),
|
|
|
|
|
generic_md(weights_h_dims), generic_md(bias_dims), dst_desc, dst_h_desc, dst_c_desc);
|
|
|
|
|
auto prim_backward_desc = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);
|
|
|
|
|
|
|
|
|
|
// construct fw memory
|
|
|
|
|
auto src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
|
|
|
|
|
write_to_dnnl_memory(inputs[0]->addr, src_memory);
|
|
|
|
|
|
|
|
|
|
auto src_h_memory = dnnl::memory(prim_forward_desc.src_iter_desc(), eng);
|
|
|
|
|
auto src_c_memory = dnnl::memory(prim_forward_desc.src_iter_c_desc(), eng);
|
|
|
|
|
write_to_dnnl_memory(inputs[1]->addr, src_h_memory);
|
|
|
|
|
write_to_dnnl_memory(inputs[2]->addr, src_c_memory);
|
|
|
|
|
|
|
|
|
|
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldigo), eng);
|
|
|
|
|
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldigo), eng);
|
|
|
|
|
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
|
|
|
|
|
write_to_dnnl_memory(inputs[3]->addr, user_weights_memory);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_, user_weights_h_memory);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_, user_bias_memory);
|
|
|
|
|
src_memory.set_data_handle(inputs[0]->addr);
|
|
|
|
|
auto src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
|
|
|
|
|
auto src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
|
|
|
|
|
src_h_memory.set_data_handle(inputs[1]->addr);
|
|
|
|
|
src_c_memory.set_data_handle(inputs[2]->addr);
|
|
|
|
|
auto user_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
|
|
|
|
|
auto user_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
|
|
|
|
|
user_weights_memory.set_data_handle(inputs[3]->addr);
|
|
|
|
|
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
|
|
|
|
|
auto weights_memory = dnnl::memory(prim_backward_desc.weights_layer_desc(), eng);
|
|
|
|
|
auto weights_h_memory = dnnl::memory(prim_backward_desc.weights_iter_desc(), eng);
|
|
|
|
|
auto bias_memory = dnnl::memory(prim_forward_desc.bias_desc(), eng);
|
|
|
|
|
dnnl::reorder(user_weights_memory, weights_memory).execute(s, user_weights_memory, weights_memory);
|
|
|
|
|
dnnl::reorder(user_weights_h_memory, weights_h_memory).execute(s, user_weights_h_memory, weights_h_memory);
|
|
|
|
|
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
|
|
|
|
|
|
|
|
|
|
// construct bias memory
|
|
|
|
|
auto bias_memory = dnnl::memory(prim_backward_desc.bias_desc(), eng);
|
|
|
|
|
if (has_bias_) {
|
|
|
|
|
auto user_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
|
|
|
|
|
user_bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
|
|
|
|
|
dnnl::reorder(user_bias_memory, bias_memory).execute(s, user_bias_memory, bias_memory);
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<float> net_bias(bias_memory.get_desc().get_size(), 0.0f);
|
|
|
|
|
write_to_dnnl_memory(net_bias.data(), bias_memory);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[4]->addr), dst_memory);
|
|
|
|
|
auto dst_h_memory = dnnl::memory(prim_backward_desc.dst_iter_desc(), eng);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[5]->addr), dst_h_memory);
|
|
|
|
|
auto dst_c_memory = dnnl::memory(prim_backward_desc.dst_iter_c_desc(), eng);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[6]->addr), dst_c_memory);
|
|
|
|
|
dst_memory.set_data_handle(inputs[4]->addr);
|
|
|
|
|
auto dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
|
|
|
|
|
auto dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
|
|
|
|
|
dst_h_memory.set_data_handle(inputs[5]->addr);
|
|
|
|
|
dst_c_memory.set_data_handle(inputs[6]->addr);
|
|
|
|
|
auto workspace_memory = dnnl::memory(prim_forward_desc.workspace_desc(), eng);
|
|
|
|
|
write_to_dnnl_memory(inputs[10]->addr, workspace_memory);
|
|
|
|
|
workspace_memory.set_data_handle(inputs[10]->addr);
|
|
|
|
|
|
|
|
|
|
// construct diff memory
|
|
|
|
|
// construct bw memory
|
|
|
|
|
std::vector<float> net_w(weights_memory.get_desc().get_size(), 0.0f);
|
|
|
|
|
std::vector<float> net_wh(weights_h_memory.get_desc().get_size(), 0.0f);
|
|
|
|
|
auto diff_src_memory = dnnl::memory(formatted_md(src_dims, tag::tnc), eng);
|
|
|
|
|
auto diff_src_h_memory = dnnl::memory(prim_backward_desc.diff_src_iter_desc(), eng);
|
|
|
|
|
auto diff_src_c_memory = dnnl::memory(prim_backward_desc.diff_src_iter_c_desc(), eng);
|
|
|
|
|
|
|
|
|
|
auto diff_src_h_memory = dnnl::memory(formatted_md(src_h_dims, tag::ldnc), eng);
|
|
|
|
|
auto diff_src_c_memory = dnnl::memory(formatted_md(src_c_dims, tag::ldnc), eng);
|
|
|
|
|
auto user_diff_weights_memory = dnnl::memory(formatted_md(weights_dims, tag::ldgoi), eng);
|
|
|
|
|
auto user_diff_weights_h_memory = dnnl::memory(formatted_md(weights_h_dims, tag::ldgoi), eng);
|
|
|
|
|
auto diff_weights_memory = dnnl::memory(prim_backward_desc.diff_weights_layer_desc(), eng);
|
|
|
|
|
auto diff_weights_h_memory = dnnl::memory(prim_backward_desc.diff_weights_iter_desc(), eng);
|
|
|
|
|
write_to_dnnl_memory(net_w.data(), diff_weights_memory);
|
|
|
|
|
write_to_dnnl_memory(net_wh.data(), diff_weights_h_memory);
|
|
|
|
|
auto user_diff_bias_memory = dnnl::memory(formatted_md(bias_dims, tag::ldgo), eng);
|
|
|
|
|
auto diff_bias_memory = dnnl::memory(prim_backward_desc.diff_bias_desc(), eng);
|
|
|
|
|
auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[7]->addr), diff_dst_memory);
|
|
|
|
|
auto diff_dst_h_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_desc(), eng);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[8]->addr), diff_dst_h_memory);
|
|
|
|
|
auto diff_dst_c_memory = dnnl::memory(prim_backward_desc.diff_dst_iter_c_desc(), eng);
|
|
|
|
|
write_to_dnnl_memory(reinterpret_cast<float *>(inputs[9]->addr), diff_dst_c_memory);
|
|
|
|
|
write_to_dnnl_memory(net_w.data(), diff_bias_memory);
|
|
|
|
|
|
|
|
|
|
auto diff_dst_memory = dnnl::memory(formatted_md(dst_dims, tag::tnc), eng);
|
|
|
|
|
diff_dst_memory.set_data_handle(inputs[7]->addr);
|
|
|
|
|
auto diff_dst_h_memory = dnnl::memory(formatted_md(dst_h_dims, tag::ldnc), eng);
|
|
|
|
|
diff_dst_h_memory.set_data_handle(inputs[8]->addr);
|
|
|
|
|
auto diff_dst_c_memory = dnnl::memory(formatted_md(dst_c_dims, tag::ldnc), eng);
|
|
|
|
|
diff_dst_c_memory.set_data_handle(inputs[9]->addr);
|
|
|
|
|
diff_src_memory.set_data_handle(outputs[0]->addr);
|
|
|
|
|
diff_src_h_memory.set_data_handle(outputs[1]->addr);
|
|
|
|
|
diff_src_c_memory.set_data_handle(outputs[2]->addr);
|
|
|
|
|
diff_weights_memory.set_data_handle(outputs[3]->addr);
|
|
|
|
|
diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
|
|
|
|
|
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
|
|
|
|
|
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
|
|
|
|
|
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
|
|
|
|
|
write_to_dnnl_memory(net_w.data(), user_diff_weights_memory);
|
|
|
|
|
write_to_dnnl_memory(net_wh.data(), user_diff_weights_h_memory);
|
|
|
|
|
|
|
|
|
|
// construct bw bias memory
|
|
|
|
|
user_diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
|
|
|
|
|
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
|
|
|
|
|
dnnl::lstm_backward bwd_layer(prim_backward_desc);
|
|
|
|
|
bwd_layer.execute(s, {{DNNL_ARG_SRC_LAYER, src_memory},
|
|
|
|
|
{DNNL_ARG_SRC_ITER, src_h_memory},
|
|
|
|
@ -163,6 +179,16 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
{DNNL_ARG_DIFF_DST_ITER, diff_dst_h_memory},
|
|
|
|
|
{DNNL_ARG_DIFF_DST_ITER_C, diff_dst_c_memory},
|
|
|
|
|
{DNNL_ARG_WORKSPACE, workspace_memory}});
|
|
|
|
|
dnnl::reorder(diff_weights_memory, user_diff_weights_memory)
|
|
|
|
|
.execute(s, diff_weights_memory, user_diff_weights_memory);
|
|
|
|
|
dnnl::reorder(diff_weights_h_memory, user_diff_weights_h_memory)
|
|
|
|
|
.execute(s, diff_weights_h_memory, user_diff_weights_h_memory);
|
|
|
|
|
if (has_bias_) {
|
|
|
|
|
dnnl::reorder(diff_bias_memory, user_diff_bias_memory).execute(s, diff_bias_memory, user_diff_bias_memory);
|
|
|
|
|
} else {
|
|
|
|
|
write_to_dnnl_memory(net_w.data(), user_diff_bias_memory);
|
|
|
|
|
}
|
|
|
|
|
s.wait();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace kernel
|
|
|
|
|