SlicePrepose support Reshape,Matmul,FC,Transpose,Arithmetic,Slice

pull/8706/head
wangzhe 4 years ago
parent 381455638a
commit 1592cf98ca

@ -199,6 +199,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
} }
auto matmul_cnode = func_graph->NewCNode(matmul_inputs); auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope()); matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope());
matmul_cnode->set_abstract(stack_cnode->abstract()->Clone());
MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success"; MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
return matmul_cnode; return matmul_cnode;
} }

@ -324,7 +324,7 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const
} }
auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, gamma_shape, epsilon); auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, gamma_shape, epsilon);
layer_norm_cnode->set_abstract(add2_cnode->abstract()); layer_norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success"; MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success";
return layer_norm_cnode; return layer_norm_cnode;

@ -27,9 +27,7 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li
std::vector<int> shape(tensor->shape()); std::vector<int> shape(tensor->shape());
auto type_id = static_cast<TypeId>(tensor->data_type()); auto type_id = static_cast<TypeId>(tensor->data_type());
auto type_ptr = TypeIdToType(type_id); auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector; std::vector<int64_t> shape_vector(shape.begin(), shape.end());
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
if (new_abstract == nullptr) { if (new_abstract == nullptr) {
MS_LOG(ERROR) << "new AbstractTensor failed"; MS_LOG(ERROR) << "new AbstractTensor failed";
@ -283,12 +281,16 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
auto primt = std::make_unique<schema::PrimitiveT>(); auto primt = std::make_unique<schema::PrimitiveT>();
if (primt == nullptr) { if (primt == nullptr) {
MS_LOG(ERROR) << "primt is nullptr"; MS_LOG(ERROR) << "primt is nullptr";
FreeTensors(&input_tensors);
FreeTensors(&output_tensors);
return false; return false;
} }
*primt = *origin_primt; *primt = *origin_primt;
auto primc = std::shared_ptr<lite::PrimitiveC>(lite::PrimitiveC::Create(primt.release())); auto primc = std::shared_ptr<lite::PrimitiveC>(lite::PrimitiveC::Create(primt.release()));
if (primc == nullptr) { if (primc == nullptr) {
MS_LOG(ERROR) << "primc is nullptr"; MS_LOG(ERROR) << "primc is nullptr";
FreeTensors(&input_tensors);
FreeTensors(&output_tensors);
return false; return false;
} }
status = primc->InferShape(input_tensors, output_tensors); status = primc->InferShape(input_tensors, output_tensors);

File diff suppressed because it is too large Load Diff

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <string>
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h" #include "backend/optimizer/common/pass.h"
#include "include/errorcode.h" #include "include/errorcode.h"
@ -40,11 +41,54 @@ class SlicePreposePass : public Pass {
void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; }
private: private:
schema::SliceT *GetSliceT(const CNodePtr &cnode); void ClearCNodeAbstractValue(const CNodePtr &cnode);
bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode);
STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode, STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode,
const int index, const TransactionPtr &tr = nullptr); const int index, const TransactionPtr &tr = nullptr);
ValueNodePtr CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector<int32_t> &axes,
const std::vector<int32_t> &begin, const std::vector<int32_t> &size);
ValueNodePtr CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode);
CNodePtr InsertSlice(const FuncGraphPtr &graph, const ValueNodePtr &slice_vnode, const CNodePtr &preceed_cnode,
const int index, const TransactionPtr &tr);
STATUS VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim = -1);
STATUS SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int32_t> &ref_shape,
std::vector<int32_t> *axes, std::vector<int32_t> *begin, std::vector<int32_t> *size);
CNodePtr CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape,
const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode);
bool SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list,
const std::vector<int32_t> &ref_shape = {});
int GetReshapeAbnormalAxeIn(const std::vector<int> &shape_in, const std::vector<int> &shape_out,
std::vector<int> *mapped_axe);
int GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector<int> &mapped_axe,
const std::vector<int> &shape_out, std::vector<int> *shape_out_copy,
bool *is_normal_mode, bool *support_abnormal_mode);
bool PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode,
const std::vector<int> &shape_in, const std::vector<int> &shape_out_copy,
const std::vector<int> &mapped_axe);
CNodePtr CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &matmul_cnode, const std::vector<int> &shape_in,
const int abnormal_axe_in, const int count_sliced_axe_in,
const bool slice_at_front);
CNodePtr CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
const CNodePtr &new_reshape1_cnode, const std::vector<int64_t> &new_shape1,
const int abnormal_axe_in, const int count_sliced_axe_in,
const int count_sliced2, const bool slice_at_front);
bool PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode,
const CNodePtr &matmul_cnode, const std::vector<int> &shape_in,
const std::vector<int> &shape_out, const int abnormal_axe_in,
const int abnormal_index_out);
bool GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs,
std::vector<std::vector<int32_t>> *shapes, std::vector<bool> *is_default_params);
bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode);
bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode); bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode);
bool PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode);
bool PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &matmul_cnode);
bool PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &fc_cnode);
bool PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &transpose_cnode);
bool PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &arithmetic_cnode);
bool MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode, const CNodePtr &slice2_cnode);
bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices);
private: private:
FmkType fmk_type = lite::converter::FmkType_ONNX; FmkType fmk_type = lite::converter::FmkType_ONNX;

Loading…
Cancel
Save