|
|
|
@ -25,7 +25,6 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr size_t kWhileCommonInputsLength = 2;
|
|
|
|
|
constexpr size_t kWhileUniqInputsLength = 6;
|
|
|
|
|
constexpr size_t kCondNodesNum = 12;
|
|
|
|
|
constexpr size_t kCondCNodesNum = 4;
|
|
|
|
@ -47,16 +46,11 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
|
|
|
|
|
: PatternProcessPass(name, multigraph) {
|
|
|
|
|
/*
|
|
|
|
|
* vars for while input
|
|
|
|
|
* common:
|
|
|
|
|
* 0:const0 1:init_state
|
|
|
|
|
* fw_while_inputs:
|
|
|
|
|
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
|
|
|
|
|
* bw_while_inputs:
|
|
|
|
|
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
|
|
|
|
|
*/
|
|
|
|
|
for (size_t i = 0; i < kWhileCommonInputsLength; ++i) {
|
|
|
|
|
common_vars_.emplace_back(std::make_shared<Var>());
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < kWhileUniqInputsLength; ++i) {
|
|
|
|
|
fw_vars_.emplace_back(std::make_shared<Var>());
|
|
|
|
|
bw_vars_.emplace_back(std::make_shared<Var>());
|
|
|
|
@ -64,17 +58,16 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
|
|
|
|
|
input_ = std::make_shared<Var>();
|
|
|
|
|
input_length_ = std::make_shared<Var>();
|
|
|
|
|
transpose_input_ = std::make_shared<Var>();
|
|
|
|
|
fw_init_state_ = std::make_shared<Var>();
|
|
|
|
|
bw_init_state_ = std::make_shared<Var>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|
|
|
|
auto const1 = std::make_shared<CondVar>(IsParameterNode);
|
|
|
|
|
auto ele_shape = std::make_shared<CondVar>(IsParameterNode);
|
|
|
|
|
|
|
|
|
|
// forward
|
|
|
|
|
auto fw_max1 =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
|
|
|
|
|
auto fw_max2 =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1});
|
|
|
|
|
auto fw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)),
|
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_max1});
|
|
|
|
|
|
|
|
|
|
auto fw_shape =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_});
|
|
|
|
@ -84,32 +77,33 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2});
|
|
|
|
|
|
|
|
|
|
auto fw_reserve =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
|
|
|
|
|
fw_stride});
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)),
|
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_stride});
|
|
|
|
|
auto fw_from_tensor =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)),
|
|
|
|
|
transpose_input_, ele_shape});
|
|
|
|
|
transpose_input_, std::make_shared<CondVar>(IsParameterNode)});
|
|
|
|
|
auto is_fw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
|
|
|
|
|
auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0],
|
|
|
|
|
fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_});
|
|
|
|
|
auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], std::make_shared<CondVar>(IsParameterNode),
|
|
|
|
|
fw_stride, std::make_shared<CondVar>(IsParameterNode), fw_reserve, fw_init_state_, fw_min,
|
|
|
|
|
fw_from_tensor, input_length_});
|
|
|
|
|
fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end());
|
|
|
|
|
fw_while.emplace_back(common_vars_[1]);
|
|
|
|
|
fw_while.emplace_back(std::make_shared<Var>());
|
|
|
|
|
auto fw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
|
|
|
|
|
fw_while, std::make_shared<Var>()});
|
|
|
|
|
auto fw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
|
|
|
|
|
fw_get_item, ele_shape});
|
|
|
|
|
auto fw_out_trans =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack});
|
|
|
|
|
fw_get_item, std::make_shared<CondVar>(IsParameterNode)});
|
|
|
|
|
auto fw_out_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)),
|
|
|
|
|
fw_stack, std::make_shared<Var>()});
|
|
|
|
|
|
|
|
|
|
// backward
|
|
|
|
|
auto bw_reverse_seq = VectorRef(
|
|
|
|
|
{std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_});
|
|
|
|
|
auto bw_max1 =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
|
|
|
|
|
auto bw_max2 =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1});
|
|
|
|
|
auto bw_trans =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq});
|
|
|
|
|
auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)),
|
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), bw_max1});
|
|
|
|
|
auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)),
|
|
|
|
|
bw_reverse_seq, std::make_shared<Var>()});
|
|
|
|
|
auto bw_shape =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans});
|
|
|
|
|
auto bw_stride =
|
|
|
|
@ -117,22 +111,23 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|
|
|
|
auto bw_min =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2});
|
|
|
|
|
auto bw_reserve =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
|
|
|
|
|
bw_stride});
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)),
|
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), bw_stride});
|
|
|
|
|
auto bw_from_tensor =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans,
|
|
|
|
|
ele_shape});
|
|
|
|
|
std::make_shared<CondVar>(IsParameterNode)});
|
|
|
|
|
auto is_bw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
|
|
|
|
|
auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0],
|
|
|
|
|
bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_});
|
|
|
|
|
auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], std::make_shared<CondVar>(IsParameterNode),
|
|
|
|
|
bw_stride, std::make_shared<CondVar>(IsParameterNode), bw_reserve, bw_init_state_, bw_min,
|
|
|
|
|
bw_from_tensor, input_length_});
|
|
|
|
|
bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end());
|
|
|
|
|
bw_while.emplace_back(common_vars_[1]);
|
|
|
|
|
bw_while.emplace_back(std::make_shared<Var>());
|
|
|
|
|
auto bw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
|
|
|
|
|
bw_while, std::make_shared<Var>()});
|
|
|
|
|
auto bw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
|
|
|
|
|
bw_get_item, ele_shape});
|
|
|
|
|
auto bw_out_trans =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack});
|
|
|
|
|
bw_get_item, std::make_shared<CondVar>(IsParameterNode)});
|
|
|
|
|
auto bw_out_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)),
|
|
|
|
|
bw_stack, std::make_shared<Var>()});
|
|
|
|
|
auto bw_reverse1 =
|
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans,
|
|
|
|
|
input_length_});
|
|
|
|
@ -416,10 +411,12 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph,
|
|
|
|
|
const AnfNodePtr &hidden_state,
|
|
|
|
|
const AnfNodePtr &fw_init_state,
|
|
|
|
|
const AnfNodePtr &bw_init_state,
|
|
|
|
|
const std::string base_name) const {
|
|
|
|
|
MS_ASSERT(func_graph);
|
|
|
|
|
MS_ASSERT(hidden_state);
|
|
|
|
|
MS_ASSERT(func_graph != nullptr);
|
|
|
|
|
MS_ASSERT(fw_init_state != nullptr);
|
|
|
|
|
MS_ASSERT(bw_init_state != nullptr);
|
|
|
|
|
auto stack_primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
std::unique_ptr<schema::StackT> attr = std::make_unique<schema::StackT>();
|
|
|
|
|
attr->axis = 0;
|
|
|
|
@ -427,9 +424,9 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f
|
|
|
|
|
stack_primitive->value.value = attr.release();
|
|
|
|
|
auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release());
|
|
|
|
|
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(stack_cvalue));
|
|
|
|
|
std::vector<AnfNodePtr> new_node_inputs = {value_node, hidden_state, hidden_state};
|
|
|
|
|
std::vector<AnfNodePtr> new_node_inputs = {value_node, fw_init_state, bw_init_state};
|
|
|
|
|
auto new_node = func_graph->NewCNode(new_node_inputs);
|
|
|
|
|
new_node->set_abstract(hidden_state->abstract()->Clone());
|
|
|
|
|
new_node->set_abstract(fw_init_state->abstract()->Clone());
|
|
|
|
|
new_node->set_fullname_with_scope("stack_hidden_" + base_name);
|
|
|
|
|
return new_node;
|
|
|
|
|
}
|
|
|
|
@ -452,31 +449,33 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr
|
|
|
|
|
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(gru_cvalue));
|
|
|
|
|
|
|
|
|
|
auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]);
|
|
|
|
|
MS_ASSERT(fw_gate_kernel);
|
|
|
|
|
MS_ASSERT(fw_gate_kernel != nullptr);
|
|
|
|
|
auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]);
|
|
|
|
|
MS_ASSERT(fw_gate_bias);
|
|
|
|
|
MS_ASSERT(fw_gate_bias != nullptr);
|
|
|
|
|
auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]);
|
|
|
|
|
MS_ASSERT(fw_cand_kernel);
|
|
|
|
|
MS_ASSERT(fw_cand_kernel != nullptr);
|
|
|
|
|
auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]);
|
|
|
|
|
MS_ASSERT(fw_cand_bias);
|
|
|
|
|
MS_ASSERT(fw_cand_bias != nullptr);
|
|
|
|
|
|
|
|
|
|
auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]);
|
|
|
|
|
MS_ASSERT(bw_gate_kernel);
|
|
|
|
|
MS_ASSERT(bw_gate_kernel != nullptr);
|
|
|
|
|
auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]);
|
|
|
|
|
MS_ASSERT(bw_gate_bias);
|
|
|
|
|
MS_ASSERT(bw_gate_bias != nullptr);
|
|
|
|
|
auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]);
|
|
|
|
|
MS_ASSERT(bw_cand_kernel);
|
|
|
|
|
MS_ASSERT(bw_cand_kernel != nullptr);
|
|
|
|
|
auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]);
|
|
|
|
|
MS_ASSERT(bw_cand_bias);
|
|
|
|
|
MS_ASSERT(bw_cand_bias != nullptr);
|
|
|
|
|
|
|
|
|
|
auto hidden = utils::cast<AnfNodePtr>((*equiv)[common_vars_[1]]);
|
|
|
|
|
MS_ASSERT(hidden);
|
|
|
|
|
auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name);
|
|
|
|
|
auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]);
|
|
|
|
|
MS_ASSERT(fw_init_state != nullptr);
|
|
|
|
|
auto bw_init_state = utils::cast<AnfNodePtr>((*equiv)[bw_init_state_]);
|
|
|
|
|
MS_ASSERT(bw_init_state != nullptr);
|
|
|
|
|
auto stacked_hidden = GetStackedHiddenState(func_graph, fw_init_state, bw_init_state, base_name);
|
|
|
|
|
if (stacked_hidden == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]);
|
|
|
|
|
MS_ASSERT(hidden);
|
|
|
|
|
MS_ASSERT(hidden != nullptr);
|
|
|
|
|
|
|
|
|
|
int input_size = 0;
|
|
|
|
|
int hidden_size = 0;
|
|
|
|
@ -536,8 +535,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr
|
|
|
|
|
|
|
|
|
|
CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
|
|
|
|
|
const std::string base_name) const {
|
|
|
|
|
MS_ASSERT(func_graph);
|
|
|
|
|
MS_ASSERT(gru_output);
|
|
|
|
|
MS_ASSERT(func_graph != nullptr);
|
|
|
|
|
MS_ASSERT(gru_output != nullptr);
|
|
|
|
|
auto split_primitive = std::make_unique<schema::PrimitiveT>();
|
|
|
|
|
std::unique_ptr<schema::SplitT> split_attr = std::make_unique<schema::SplitT>();
|
|
|
|
|
split_attr->numberSplit = 2;
|
|
|
|
@ -603,8 +602,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
|
|
|
|
|
const EquivPtr &equiv) const {
|
|
|
|
|
MS_ASSERT(func_graph);
|
|
|
|
|
MS_ASSERT(concat_node);
|
|
|
|
|
MS_ASSERT(func_graph != nullptr);
|
|
|
|
|
MS_ASSERT(concat_node != nullptr);
|
|
|
|
|
MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
|
|
|
|
|
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) {
|
|
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
|
|
@ -612,7 +611,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
|
|
|
|
|
MS_ASSERT(transpose_input);
|
|
|
|
|
MS_ASSERT(transpose_input != nullptr);
|
|
|
|
|
if (!utils::isa<CNodePtr>(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|