!13370 add lstm training attr

From: @huaweib
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
pull/13370/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 075d737127

@ -22,6 +22,12 @@
namespace mindspore {
namespace kernel {
const int kMaxLSTMLayer = 100;
const int kOutputWorkSpaceIndex = 3;
void LstmCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
output_size_list_[kOutputWorkSpaceIndex] = reserve_size_;
}
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
#ifdef PLATFORM_86
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
@ -53,12 +59,25 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
auto desc = std::make_shared<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc,
src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
dst_h_desc, dst_c_desc);
if (!kernel_node->HasAttr(kAttrIsTraining)) {
MS_LOG(WARNING) << "LSTM has no attr is_training";
}
is_training = GetValue<bool>(kernel_node->GetAttr(kAttrIsTraining));
auto prop_kind = dnnl::prop_kind::forward_training;
if (!is_training) {
prop_kind = dnnl::prop_kind::forward_inference;
}
auto desc = std::make_shared<dnnl::lstm_forward::desc>(
prop_kind, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc);
prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng);
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
if (is_training) {
reserve_size_ = static_cast<size_t>(prim_desc_.workspace_desc().get_size());
AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
} else {
reserve_size_ = 1;
}
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
@ -68,7 +87,6 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
}
void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
@ -140,7 +158,9 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr);
if (is_training) {
SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr);
}
ExecutePrimitive();
return true;
}

@ -36,6 +36,9 @@ class LstmCPUKernel : public MKLCPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
void InitInputOutputSize(const CNodePtr &kernel_node) override;
private:
void CheckParam(const CNodePtr &kernel_node);
int weight_size_ = 0;
@ -48,6 +51,8 @@ class LstmCPUKernel : public MKLCPUKernel {
int num_directions_;
bool bidirectional_;
bool has_bias_;
size_t reserve_size_;
bool is_training;
dnnl::memory::dims weights_dims_;
dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_;

@ -23,6 +23,12 @@
namespace mindspore {
namespace kernel {
const int kMaxLSTMLayer = 100;
const int kInputWorkSpaceIndex = 10;
void LSTMGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
input_size_list_[kInputWorkSpaceIndex] = reserve_size_;
}
void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
using tag = dnnl::memory::format_tag;
@ -61,6 +67,7 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dst_h_desc, dst_c_desc);
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
reserve_size_ = static_cast<size_t>(prim_forward_desc.workspace_desc().get_size());
AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
}

@ -32,6 +32,9 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
void InitInputOutputSize(const CNodePtr &kernel_node) override;
private:
void AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc,
@ -54,6 +57,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel {
int num_directions_;
bool bidirectional_;
bool has_bias_;
size_t reserve_size_;
dnnl::memory::dims weights_dims_;
dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_;

@ -24,6 +24,7 @@
#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h"
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
#include "backend/optimizer/pass/add_training_attr.h"
#include "utils/ms_context.h"
#include "debug/anf_ir_dump.h"
@ -48,6 +49,7 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
common_pm->AddPass(std::make_shared<ConvertConstScalarToTensor>());
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
common_pm->AddPass(std::make_shared<AddTrainingAttr>());
optimizer->AddPassManager(common_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();

@ -0,0 +1,93 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/pass/add_training_attr.h"
#include <vector>
#include <memory>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "ir/graph_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace opt {
namespace {
std::unordered_map<std::string, std::unordered_set<std::string>> MarkOp{
{"LSTM", {"LSTMGradWeight", "LSTMGrad", "LSTMGradData"}}};
bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const std::unordered_set<std::string> &set) {
for (const auto &node_index : manager->node_users()[cnode]) {
auto output = node_index.first;
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
if (CheckOP(manager, output, set)) {
return true;
}
} else if (output->isa<CNode>()) {
auto name = AnfAlgo::GetCNodeName(output);
if (set.find(name) != set.end()) {
return true;
}
}
}
return false;
}
void AddAttrTraining(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(cnode) == manager->node_users().end()) {
return;
}
auto set = MarkOp[AnfAlgo::GetCNodeName(cnode)];
if (CheckOP(manager, cnode, set)) {
cnode->AddAttr(kAttrIsTraining, MakeValue(true));
} else {
cnode->AddAttr(kAttrIsTraining, MakeValue(false));
}
}
} // namespace
const AnfNodePtr AddTrainingAttr::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
return nullptr;
}
if (!node->isa<CNode>()) {
return nullptr;
}
auto name = AnfAlgo::GetCNodeName(node);
auto iter = MarkOp.find(name);
if (iter == MarkOp.end()) {
return nullptr;
}
if (AnfAlgo::IsGraphKernel(node)) {
return nullptr;
} else {
auto cnode = node->cast<CNodePtr>();
AddAttrTraining(func_graph, cnode);
return cnode;
}
}
} // namespace opt
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H
#include <string>
#include "ir/anf.h"
#include "utils/convert_utils.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class AddTrainingAttr : public PatternProcessPass {
public:
explicit AddTrainingAttr(bool multigraph = true) : PatternProcessPass("add_training_attr", multigraph) {}
~AddTrainingAttr() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_TRAINING_ATTR_H

@ -3885,29 +3885,7 @@ class LSTM(PrimitiveWithInfer):
y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions)
# set arbitrary shape for reserved space
type_size = 4
gates_ws_ld = self.get_good_ld(self.hidden_size * 4, type_size)
states_ws_ld = self.get_good_ld(max(self.hidden_size, self.input_size), type_size)
self.ws_gates_size = self.num_layers * self.num_directions * x_shape[0] * x_shape[1] * gates_ws_ld * type_size
self.ws_states_size = (self.num_layers + 1) * self.num_directions * (x_shape[0] + 1) * x_shape[
1] * states_ws_ld * type_size
self.ws_c_states_size = (self.num_layers + 1) * self.num_directions * (x_shape[0] + 1) * x_shape[
1] * states_ws_ld * type_size
self.ws_diff_states_size = (self.num_layers + 1) * self.num_directions * (x_shape[0] + 1) * (2 + 1) * x_shape[
1] * states_ws_ld * type_size
self.ws_grid_comp_size = 0
self.page_size = 4096
current_offset = 0
current_offset += self.ws_gates_size
current_offset = self.rnd_up(current_offset, self.page_size)
current_offset += self.ws_states_size
current_offset = self.rnd_up(current_offset, self.page_size)
current_offset += self.ws_c_states_size
current_offset = self.rnd_up(current_offset, self.page_size)
current_offset += self.ws_diff_states_size
current_offset = self.rnd_up(current_offset, self.page_size)
current_offset += self.ws_grid_comp_size
reserved_shape = (current_offset, 1)
reserved_shape = (1, 1)
state_shape = (1, 1)
return (y_shape, h_shape, c_shape, reserved_shape, state_shape)
@ -3916,15 +3894,6 @@ class LSTM(PrimitiveWithInfer):
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
def rnd_up(self, current_offset, page_size):
return ((current_offset + page_size - 1) // page_size) * page_size
def get_good_ld(self, dim, type_size):
ld = self.rnd_up(dim, 64 // type_size)
if ld * 256 == 0:
return ld + 64 // type_size
return ld
class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
r"""

Loading…
Cancel
Save