|
|
|
@ -19,6 +19,7 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "tools/converter/converter_flags.h"
|
|
|
|
|
#include "backend/optimizer/common/pass.h"
|
|
|
|
|
#include "include/errorcode.h"
|
|
|
|
@ -40,11 +41,54 @@ class SlicePreposePass : public Pass {
|
|
|
|
|
void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
schema::SliceT *GetSliceT(const CNodePtr &cnode);
|
|
|
|
|
bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode);
|
|
|
|
|
void ClearCNodeAbstractValue(const CNodePtr &cnode);
|
|
|
|
|
STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode,
|
|
|
|
|
const int index, const TransactionPtr &tr = nullptr);
|
|
|
|
|
ValueNodePtr CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector<int32_t> &axes,
|
|
|
|
|
const std::vector<int32_t> &begin, const std::vector<int32_t> &size);
|
|
|
|
|
ValueNodePtr CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode);
|
|
|
|
|
CNodePtr InsertSlice(const FuncGraphPtr &graph, const ValueNodePtr &slice_vnode, const CNodePtr &preceed_cnode,
|
|
|
|
|
const int index, const TransactionPtr &tr);
|
|
|
|
|
STATUS VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim = -1);
|
|
|
|
|
STATUS SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int32_t> &ref_shape,
|
|
|
|
|
std::vector<int32_t> *axes, std::vector<int32_t> *begin, std::vector<int32_t> *size);
|
|
|
|
|
CNodePtr CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape,
|
|
|
|
|
const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode);
|
|
|
|
|
bool SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list,
|
|
|
|
|
const std::vector<int32_t> &ref_shape = {});
|
|
|
|
|
int GetReshapeAbnormalAxeIn(const std::vector<int> &shape_in, const std::vector<int> &shape_out,
|
|
|
|
|
std::vector<int> *mapped_axe);
|
|
|
|
|
int GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector<int> &mapped_axe,
|
|
|
|
|
const std::vector<int> &shape_out, std::vector<int> *shape_out_copy,
|
|
|
|
|
bool *is_normal_mode, bool *support_abnormal_mode);
|
|
|
|
|
bool PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode,
|
|
|
|
|
const std::vector<int> &shape_in, const std::vector<int> &shape_out_copy,
|
|
|
|
|
const std::vector<int> &mapped_axe);
|
|
|
|
|
CNodePtr CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
|
|
|
|
|
const CNodePtr &matmul_cnode, const std::vector<int> &shape_in,
|
|
|
|
|
const int abnormal_axe_in, const int count_sliced_axe_in,
|
|
|
|
|
const bool slice_at_front);
|
|
|
|
|
CNodePtr CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
|
|
|
|
|
const CNodePtr &new_reshape1_cnode, const std::vector<int64_t> &new_shape1,
|
|
|
|
|
const int abnormal_axe_in, const int count_sliced_axe_in,
|
|
|
|
|
const int count_sliced2, const bool slice_at_front);
|
|
|
|
|
bool PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode,
|
|
|
|
|
const CNodePtr &matmul_cnode, const std::vector<int> &shape_in,
|
|
|
|
|
const std::vector<int> &shape_out, const int abnormal_axe_in,
|
|
|
|
|
const int abnormal_index_out);
|
|
|
|
|
bool GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs,
|
|
|
|
|
std::vector<std::vector<int32_t>> *shapes, std::vector<bool> *is_default_params);
|
|
|
|
|
|
|
|
|
|
bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode);
|
|
|
|
|
|
|
|
|
|
bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode);
|
|
|
|
|
bool PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode);
|
|
|
|
|
bool PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &matmul_cnode);
|
|
|
|
|
bool PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &fc_cnode);
|
|
|
|
|
bool PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &transpose_cnode);
|
|
|
|
|
bool PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &arithmetic_cnode);
|
|
|
|
|
bool MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode, const CNodePtr &slice2_cnode);
|
|
|
|
|
bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
FmkType fmk_type = lite::converter::FmkType_ONNX;
|
|
|
|
|