|
|
@ -70,15 +70,15 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
|
|
|
|
|
|
|
|
|
|
|
|
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|
|
|
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|
|
|
// forward
|
|
|
|
// forward
|
|
|
|
auto fw_max1 =
|
|
|
|
auto fw_reduce = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)),
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_});
|
|
|
|
input_length_, std::make_shared<CondVar>(IsParameterNode)});
|
|
|
|
auto fw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
|
|
|
|
auto fw_max = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_max1});
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_reduce});
|
|
|
|
|
|
|
|
|
|
|
|
auto fw_shape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape)), transpose_input_});
|
|
|
|
auto fw_shape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape)), transpose_input_});
|
|
|
|
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice)), fw_shape,
|
|
|
|
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice)), fw_shape,
|
|
|
|
std::make_shared<SeqVar>()});
|
|
|
|
std::make_shared<SeqVar>()});
|
|
|
|
auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max2});
|
|
|
|
auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max});
|
|
|
|
|
|
|
|
|
|
|
|
auto fw_reserve = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve)),
|
|
|
|
auto fw_reserve = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve)),
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_stride});
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_stride});
|
|
|
@ -100,8 +100,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|
|
|
// backward
|
|
|
|
// backward
|
|
|
|
auto bw_reverse_seq =
|
|
|
|
auto bw_reverse_seq =
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence)), input_, input_length_});
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence)), input_, input_length_});
|
|
|
|
auto bw_max1 =
|
|
|
|
auto bw_max1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_,
|
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_});
|
|
|
|
std::make_shared<CondVar>(IsParameterNode)});
|
|
|
|
auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
|
|
|
|
auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), bw_max1});
|
|
|
|
std::make_shared<CondVar>(IsParameterNode), bw_max1});
|
|
|
|
auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), bw_reverse_seq,
|
|
|
|
auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), bw_reverse_seq,
|
|
|
|