@@ -28,37 +28,8 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  using tag = dnnl::memory::format_tag;
  using dim = dnnl::memory::dims;
  CheckParam(kernel_node);
  auto eng = MKLKernelEngine::Get().engine();
  std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
  bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
  input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
  hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
  num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
  has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
  batch_size_ = SizeToInt(src_shape[1]);
  seq_len_ = SizeToInt(src_shape[0]);
  num_directions_ = 1;
  if (bidirectional_) {
    num_directions_ = 2;
  }
  if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
    MS_LOG(EXCEPTION) << "error: iteration shape mismatch!";
  }
  if (num_layers_ <= 0) {
    MS_LOG(EXCEPTION) << "layers must be greater than zero!";
  }
  if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
    MS_LOG(EXCEPTION) << "lstm only supports 3-D input!";
  }
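  // An LSTM packs four gates (input, forget, cell, output) per layer, so each
  // weight matrix contributes 4 * hidden_size rows; layer 0 consumes the raw
  // input while deeper layers consume hidden_size * num_directions features.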
  const int gate_size = 4 * hidden_size_;
  for (int i = 0; i < num_layers_; ++i) {
    weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
    weight_h_size_ += gate_size * hidden_size_;
  }
  weight_size_ = weight_size_ * num_directions_;
  weight_h_size_ = weight_h_size_ * num_directions_;
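  // bidirectional_concat runs both directions and concatenates their outputs
  // along the hidden dimension, matching the forward LSTM kernel.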
  dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
  if (bidirectional_) {
    direction = dnnl::rnn_direction::bidirectional_concat;
@@ -91,7 +62,14 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
                                   dst_h_desc, dst_c_desc);
  prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
  primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
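  // The forward primitive's workspace stores the intermediate activations the
  // backward pass needs to recompute gradients.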
  AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
  AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
}

void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
                                      const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc,
                                      const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc,
                                      const dnnl::memory::desc &dst_c_desc) {
  AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
  AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
@@ -101,7 +79,6 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
  AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
  AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
  AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
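  // Each diff_* argument reuses the descriptor of the tensor it differentiates,
  // since a gradient has the same shape and layout as its source tensor.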
  AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc);
@@ -113,6 +90,72 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc);
}

void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
  bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
  input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
  hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
  num_layers_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "num_layers");
  has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
  batch_size_ = SizeToInt(src_shape[1]);
  seq_len_ = SizeToInt(src_shape[0]);
  num_directions_ = 1;
  if (bidirectional_) {
    num_directions_ = 2;
  }
  const int gate_size = 4 * hidden_size_;
  for (int i = 0; i < num_layers_; ++i) {
    weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
    weight_h_size_ += gate_size * hidden_size_;
  }
  weight_size_ = weight_size_ * num_directions_;
  weight_h_size_ = weight_h_size_ * num_directions_;
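  // src_h is expected to be (num_layers * num_directions, batch, hidden_size),
  // hence the leading-dimension check below.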
  if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
    MS_LOG(EXCEPTION) << "error: iteration shape mismatch!";
  }
  if (num_layers_ <= 0) {
    MS_LOG(EXCEPTION) << "layers must be greater than zero!";
  }
  if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
    MS_LOG(EXCEPTION) << "lstm only supports 3-D input!";
  }
}

void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &outputs,
                                            const dnnl::memory &weights_memory, const dnnl::memory &weights_h_memory,
                                            const dnnl::memory &bias_memory, const dnnl::memory &diff_weights_memory,
                                            const dnnl::memory &diff_weights_h_memory,
                                            const dnnl::memory &diff_bias_memory) {
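  // Index map (as bound below): inputs[0..2] = x/h/c, inputs[4..6] = y/hy/cy,
  // inputs[7..9] = dy/dhy/dcy, inputs[10] = forward workspace;
  // outputs[0..2] receive dx/dh/dc.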
  SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
  SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
}

void LSTMGradCPUKernel::Memset_op(const dnnl::memory &mem, string name) {
  if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) {
    MS_LOG(EXCEPTION) << name << " memset error";
  }
}

bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                               const std::vector<kernel::AddressPtr> &workspace,
                               const std::vector<kernel::AddressPtr> &outputs) {
@@ -145,14 +188,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
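  // tag::ldgoi orders RNN weights as (layers, directions, gates, output, input),
  // the plain layout of the user-visible flattened weight-gradient buffer.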
  auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
  user_diff_weights_memory.set_data_handle(outputs[3]->addr);
  user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
  if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0,
               user_diff_weights_memory.get_desc().get_size())) {
    MS_LOG(EXCEPTION) << "user weights grad memset error";
  }
  if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0,
               user_diff_weights_h_memory.get_desc().get_size())) {
    MS_LOG(EXCEPTION) << "user weights iter grad memset error";
  }
  Memset_op(user_diff_weights_memory, "user weights grad");
  Memset_op(user_diff_weights_h_memory, "user weights iter grad");
  Memset_op(diff_weights_memory, "weights grad");
  Memset_op(diff_weights_h_memory, "weights iter grad");
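  // outputs[3] packs the weight gradients as [weights | weights_h | bias],
  // so the bias gradient starts after both weight blocks.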
  if (has_bias_) {
    diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
  }
@@ -160,33 +199,8 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
               prim_backward_desc_.diff_bias_desc().get_size())) {
    MS_LOG(EXCEPTION) << "bias grad memset error";
  }
  if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0,
               diff_weights_memory.get_desc().get_size())) {
    MS_LOG(EXCEPTION) << "weights grad memset error";
  }
  if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0,
               diff_weights_h_memory.get_desc().get_size())) {
    MS_LOG(EXCEPTION) << "weights iter grad memset error";
  }
  SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
  SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
  SetArgumentHandleOp(inputs, outputs, weights_memory, weights_h_memory, bias_memory, diff_weights_memory,
                      diff_weights_h_memory, diff_bias_memory);
  ExecutePrimitive();
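  // Reorder copies the computed gradients from oneDNN's preferred internal
  // layout back into the plain user buffers set up above.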
  Reorder(&diff_weights_memory, &user_diff_weights_memory);
  Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory);