From 1592cf98ca3b1425d953e78bf7ca193629990d6f Mon Sep 17 00:00:00 2001 From: wangzhe Date: Mon, 16 Nov 2020 19:23:53 +0800 Subject: [PATCH] SlicePrepose support Reshape,Matmul,FC,Transpose,Arithmetic,Slice --- .../optimizer/fusion/batchmatmul_fusion.cc | 1 + .../optimizer/fusion/layer_norm_fusion.cc | 2 +- .../tools/optimizer/graph/infershape_pass.cc | 8 +- .../optimizer/graph/slice_prepose_pass.cc | 1192 ++++++++++++++++- .../optimizer/graph/slice_prepose_pass.h | 48 +- 5 files changed, 1231 insertions(+), 20 deletions(-) diff --git a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc index c5ba9ae118..77a7f6110e 100644 --- a/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -199,6 +199,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons } auto matmul_cnode = func_graph->NewCNode(matmul_inputs); matmul_cnode->set_fullname_with_scope("matmul_" + stack_cnode->fullname_with_scope()); + matmul_cnode->set_abstract(stack_cnode->abstract()->Clone()); MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success"; return matmul_cnode; } diff --git a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc index 8db4c793f6..53e3faba1f 100644 --- a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc @@ -324,7 +324,7 @@ const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const } auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, gamma_shape, epsilon); - layer_norm_cnode->set_abstract(add2_cnode->abstract()); + layer_norm_cnode->set_abstract(add2_cnode->abstract()->Clone()); layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope()); MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success"; return layer_norm_cnode; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index c794e3a364..6dcb0ad96d 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -27,9 +27,7 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li std::vector shape(tensor->shape()); auto type_id = static_cast(tensor->data_type()); auto type_ptr = TypeIdToType(type_id); - std::vector shape_vector; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); + std::vector shape_vector(shape.begin(), shape.end()); auto new_abstract = std::make_shared(type_ptr, shape_vector); if (new_abstract == nullptr) { MS_LOG(ERROR) << "new AbstractTensor failed"; @@ -283,12 +281,16 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { auto primt = std::make_unique(); if (primt == nullptr) { MS_LOG(ERROR) << "primt is nullptr"; + FreeTensors(&input_tensors); + FreeTensors(&output_tensors); return false; } *primt = *origin_primt; auto primc = std::shared_ptr(lite::PrimitiveC::Create(primt.release())); if (primc == nullptr) { MS_LOG(ERROR) << "primc is nullptr"; + FreeTensors(&input_tensors); + FreeTensors(&output_tensors); return false; } status = primc->InferShape(input_tensors, output_tensors); diff --git a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc index 296e6ab946..3a6a4e8fa4 100644 --- a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.cc @@ -16,6 +16,8 @@ #include "tools/optimizer/graph/slice_prepose_pass.h" #include #include +#include +#include #include "mindspore/lite/include/errorcode.h" #include "tools/optimizer/common/gllo_utils.h" #include "backend/optimizer/common/helper.h" @@ -26,6 +28,7 @@ using mindspore::lite::PrimitiveC; namespace mindspore::opt { namespace { +const int kArithmeticInputNum = 2; std::vector GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) { MS_ASSERT(cnode != nullptr); std::vector empty_shape; @@ -54,9 +57,36 @@ std::vector GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) } return param_value_lite->tensor_shape(); } -} // namespace -schema::SliceT *SlicePreposePass::GetSliceT(const CNodePtr &cnode) { +std::vector GetDefaultParamShape(const ParameterPtr ¶m) { + MS_ASSERT(param != nullptr); + MS_ASSERT(param->has_default()); + std::vector shape; + auto default_param = param->default_param(); + if (default_param == nullptr) { + MS_LOG(ERROR) << "default_param is nullptr"; + return shape; + } + if (!utils::isa(default_param)) { + MS_LOG(ERROR) << "default_param is not ParamValueLite"; + return shape; + } + auto param_value_lite = utils::cast(default_param); + return param_value_lite->tensor_shape(); +} + +bool IsScalarNode(const AnfNodePtr &nodePtr) { + if (utils::isa(nodePtr) && nodePtr->cast()->has_default()) { + auto tensor = utils::cast(utils::cast(nodePtr)->default_param()); + auto shape = tensor->tensor_shape(); + if (shape.empty() || (shape.size() == 1 && shape[0] == 1)) { + return true; + } + } + return false; +} + +schema::SliceT *GetSliceT(const CNodePtr &cnode) { if (cnode == nullptr) { return nullptr; } @@ -71,6 +101,62 @@ schema::SliceT *SlicePreposePass::GetSliceT(const CNodePtr &cnode) { return primt->value.AsSlice(); } +schema::SoftMaxT *GetSoftmaxT(const CNodePtr &cnode) { + if (cnode == nullptr) { + return nullptr; + } + auto primc = GetValueNode>(cnode->input(0)); + if (primc == nullptr) { + return nullptr; + } + auto primt = primc->GetPrimitiveT(); + if (primt == nullptr || primt->value.AsSoftMax() == nullptr) { + return nullptr; + } + return primt->value.AsSoftMax(); +} + +schema::ReshapeT *GetReshapeT(const CNodePtr &cnode) { + if (cnode == nullptr) { + return nullptr; + } + auto primc = GetValueNode>(cnode->input(0)); + if (primc == nullptr) { + return nullptr; + } + auto primt = primc->GetPrimitiveT(); + if (primt == nullptr || primt->value.AsReshape() == nullptr) { + return nullptr; + } + return primt->value.AsReshape(); +} + +schema::FullConnectionT *GetFcT(const CNodePtr &cnode) { + if (cnode == nullptr) { + return nullptr; + } + auto primc = GetValueNode>(cnode->input(0)); + if (primc == nullptr) { + return nullptr; + } + auto primt = primc->GetPrimitiveT(); + if (primt == nullptr || primt->value.AsFullConnection() == nullptr) { + return nullptr; + } + return primt->value.AsFullConnection(); +} +} // namespace + +void SlicePreposePass::ClearCNodeAbstractValue(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + auto abstract = cnode->abstract(); + MS_ASSERT(abstract != nullptr); + if (!utils::isa(abstract)) { + MS_LOG(DEBUG) << "Abstract of cnode is not abstract tensor, " << cnode->fullname_with_scope(); + } + abstract->set_value(std::make_shared()); +} + STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode, const int index, const TransactionPtr &tr) { @@ -107,6 +193,558 @@ STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const C return RET_OK; } +ValueNodePtr SlicePreposePass::CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector &axes, + const std::vector &begin, + const std::vector &size) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new SliceT failed"; + return nullptr; + } + attr->axes = axes; + attr->begin = begin; + attr->size = size; + auto new_primitive_t = std::make_unique(); + if (new_primitive_t == nullptr) { + MS_LOG(ERROR) << "primitive_t is nullptr"; + return nullptr; + } + new_primitive_t->value.type = schema::PrimitiveType_Slice; + new_primitive_t->value.value = attr.release(); + auto new_primtive_c = std::shared_ptr(PrimitiveC::Create(new_primitive_t.release())); + if (new_primtive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return nullptr; + } + ValueNodePtr value_node = NewValueNode(new_primtive_c); + return value_node; +} + +ValueNodePtr SlicePreposePass::CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + auto primitive_c = GetValueNode>(slice_cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return nullptr; + } + auto primitive_t = primitive_c->GetPrimitiveT(); + auto new_primitive_t = std::make_unique(); + if (new_primitive_t == nullptr) { + MS_LOG(ERROR) << "primitive_t is nullptr"; + return nullptr; + } + *new_primitive_t = *primitive_t; + auto new_primitive_c = std::make_shared(new_primitive_t.release()); + if (new_primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return nullptr; + } + ValueNodePtr value_node = NewValueNode(new_primitive_c); + return value_node; +} + +CNodePtr SlicePreposePass::InsertSlice(const FuncGraphPtr &graph, const ValueNodePtr &slice_vnode, + const CNodePtr &preceed_cnode, const int index, const TransactionPtr &tr) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(preceed_cnode != nullptr); + auto slice_cnode = graph->NewCNode({slice_vnode, preceed_cnode->input(index)}); + tr->SetEdge(preceed_cnode, index, slice_cnode); + return slice_cnode; +} + +STATUS SlicePreposePass::VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim) { + // according to ops/slice.cc, axes >= 0, begin >= 0, size >= -1 + schema::SliceT *slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "SliceT* is nullptr"; + return RET_ERROR; + } + auto &axes = slice_t->axes; + auto &begin = slice_t->begin; + auto &size = slice_t->size; + + std::set unique_axes(axes.begin(), axes.end()); + if (axes.empty() || unique_axes.size() != axes.size()) { + MS_LOG(DEBUG) << "Invalid slice axe attribute"; + return RET_ERROR; + } + for (size_t i = 0; i < axes.size(); ++i) { + int axe = axes[i]; + if (dim > -1 && axe >= dim) { + MS_LOG(ERROR) << "Invalid slice axe attribute"; + return RET_ERROR; + } + if (axe < 0) { + MS_LOG(ERROR) << "Invalid slice axe attribute"; + return RET_ERROR; + } + if (begin[i] < 0) { // we not require begin[i] < ref_shape[axe], cause there may be broadcast + MS_LOG(ERROR) << "Invalid begin input! begin[" << i << "]=" << begin[i]; + return RET_ERROR; + } + if (size[i] < -1) { + MS_LOG(ERROR) << "Invalid size input! size[" << i << "]=" << size[i]; + return RET_ERROR; + } + } + return RET_OK; +} + +/* + * Adjust slice's attr when broadcast happened in Arithmetic + */ +STATUS SlicePreposePass::SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector &ref_shape, + std::vector *axes, std::vector *begin, + std::vector *size) { + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(new_slice_cnode != nullptr); + auto slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "slice_t is nullptr"; + return RET_ERROR; + } + auto origin_axes = slice_t->axes; + auto origin_begin = slice_t->begin; + auto origin_size = slice_t->size; + auto status = VerifySliceAttrs(slice_cnode, ref_shape.size()); + if (status != RET_OK) { + return status; + } + axes->resize(ref_shape.size()); + std::iota(axes->begin(), axes->end(), 0); + begin->assign(ref_shape.size(), 0); + size->assign(ref_shape.size(), -1); + bool real_slice = false; // whether slice happened at this input + for (size_t i = 0; i < origin_axes.size(); ++i) { + int a = origin_axes[i]; + int b = origin_begin[i]; + int s = origin_size[i]; + int ref = ref_shape[a]; + if (ref == 1) { // broadcast + continue; // sliced size is 0(such as begin=1,size=-1) is not considered. + } else if (ref > 1) { // not broadcast + if (b >= ref) { + MS_LOG(ERROR) << "slice begin[" << a << "]=" << b << ", while ref_shape[" << a << "]=" << ref << ", can't fit!"; + return RET_ERROR; + } else { + if (b != 0 || (s != -1 && s != ref)) { + real_slice = true; + } + begin->at(a) = b; + size->at(a) = s; + } + } else { // ref == 0, not need slice + continue; + } + } + if (real_slice) { + return lite::RET_OK; + } else { + return lite::RET_NO_CHANGE; + } +} + +CNodePtr SlicePreposePass::CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector &shape, + const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new SliceT failed"; + return nullptr; + } + attr->shape = shape; + auto new_primitive_t = std::make_unique(); + if (new_primitive_t == nullptr) { + MS_LOG(ERROR) << "primitive_t is nullptr"; + return nullptr; + } + new_primitive_t->value.type = schema::PrimitiveType_Reshape; + new_primitive_t->value.value = attr.release(); + auto new_primtive_c = std::shared_ptr(PrimitiveC::Create(new_primitive_t.release())); + if (new_primtive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return nullptr; + } + ValueNodePtr value_node = NewValueNode(new_primtive_c); + if (value_node == nullptr) { + return nullptr; + } + auto reshape_cnode = graph->NewCNode({value_node, preceed_cnode}); + reshape_cnode->set_abstract(abstract); + ClearCNodeAbstractValue(reshape_cnode); + return reshape_cnode; +} + +bool SlicePreposePass::SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list, + const std::vector &ref_shape) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(output_node_list != nullptr); + MS_ASSERT(output_node_list->size() >= 2); + std::vector slices; + for (auto &output_node : *(output_node_list.get())) { + auto cnode = output_node.first->cast(); + if (cnode == nullptr) { + MS_LOG(ERROR) << "cnode is nullptr"; + return false; + } + if (GetCNodeType(cnode) != schema::PrimitiveType_Slice) { + return false; + } + schema::SliceT *slice_t = GetSliceT(cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "SliceT* is nullptr"; + return false; + } + slices.push_back(slice_t); + } + + auto first_slice_t = slices.front(); + auto first_axes = first_slice_t->axes; + auto first_begin = first_slice_t->begin; + auto first_size = first_slice_t->size; + for (size_t i = 1; i < output_node_list->size(); ++i) { + auto slice_t = slices[i]; + auto axes = slice_t->axes; + auto begin = slice_t->begin; + auto size = slice_t->size; + if (axes.size() != first_axes.size()) { + return false; + } + for (size_t j = 0; j < axes.size(); ++j) { + auto axe = axes[j]; + if (!ref_shape.empty() && axe >= static_cast(ref_shape.size())) { + return false; + } + size_t k = 0; + for (; k < first_axes.size(); ++k) { // axes may not be [0...n-1], so we use nested loop to find it + if (first_axes[k] == axe) { + break; + } + } + if (k == first_axes.size()) { + return false; + } + if (begin[j] != first_begin[k]) { + return false; + } + if (size[j] != first_size[k]) { + if (ref_shape.empty()) { + return false; + } + auto actual_size = size[j] > 0 ? size[j] : ref_shape[axe] - begin[j]; + auto actual_first_size = first_size[k] > 0 ? first_size[k] : ref_shape[axe] - first_begin[k]; + if (actual_size != actual_first_size) { + return false; + } + } + } + } + return true; +} + +int SlicePreposePass::GetReshapeAbnormalAxeIn(const std::vector &shape_in, const std::vector &shape_out, + std::vector *mapped_axe) { + // find shape_out's correspond axe in shape_in + // when there are such as 3x1x1x4 => 3x1x4, mapped_axe[1] == 2 + int32_t inner_size_in = 1; + int abnormal_axe_in = -1; + for (size_t i = 0; i < shape_in.size(); ++i) { + inner_size_in *= shape_in[i]; + int32_t inner_size_out = 1; + size_t j; + for (j = 0; j < shape_out.size(); ++j) { + inner_size_out *= shape_out[j]; + if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) { + mapped_axe->at(j) = i; + break; + } + } + if (j == shape_out.size() && abnormal_axe_in == -1) { + abnormal_axe_in = i; + } + } + return abnormal_axe_in; +} + +int SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector &mapped_axe, + const std::vector &shape_out, std::vector *shape_out_copy, + bool *is_normal_mode, bool *support_abnormal_mode) { + MS_ASSERT(slice_cnode != nullptr); + auto slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "slice_t is nullptr"; + return false; + } + auto slice_axes = slice_t->axes; + auto slice_begin = slice_t->begin; + auto slice_size = slice_t->size; + int abnormal_index_out = -1; + for (size_t j = 0; j < shape_out.size(); ++j) { + int index = -1; + for (size_t i = 0; i < slice_axes.size(); ++i) { + if (slice_axes[i] == static_cast(j)) { + index = i; + break; + } + } + if (index == -1) continue; + if (slice_begin[index] != 0 || (slice_size[index] != -1 && slice_size[index] != shape_out[j])) { + if (mapped_axe[j] == -1) { + if (is_normal_mode) { + *is_normal_mode = false; + abnormal_index_out = index; + } else { + *support_abnormal_mode = false; + } + } else { // if there is matched axe sliced, not support abnormal mode + shape_out_copy->at(j) = (slice_size[index] == -1 ? shape_out[j] - slice_begin[index] : slice_size[index]); + *support_abnormal_mode = false; + } + } + } + return abnormal_index_out; +} + +bool SlicePreposePass::PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &reshape_cnode, const std::vector &shape_in, + const std::vector &shape_out_copy, + const std::vector &mapped_axe) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(reshape_cnode != nullptr); + auto slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "slice_t is nullptr"; + return false; + } + auto slice_axes = slice_t->axes; + auto slice_begin = slice_t->begin; + auto slice_size = slice_t->size; + std::vector new_axes(shape_in.size()); + std::iota(new_axes.begin(), new_axes.end(), 0); + std::vector new_begin(shape_in.size(), 0); + std::vector new_size(shape_in.size(), -1); + + for (size_t i = 0; i < mapped_axe.size(); ++i) { + auto axe_in = mapped_axe[i]; + if (axe_in == -1) { + continue; + } + new_begin[axe_in] = slice_begin[i]; + new_size[axe_in] = slice_size[i]; + } + + auto reshape_t = GetReshapeT(reshape_cnode); + if (reshape_t == nullptr) { + MS_LOG(ERROR) << "reshape_t is nullptr"; + return false; + } + reshape_t->shape = std::vector(shape_out_copy.begin(), shape_out_copy.end()); + auto reshape_origin_inputs = reshape_cnode->inputs(); + if (reshape_origin_inputs.size() < 2) { + MS_LOG(ERROR) << "Reshape inputs num is illegal"; + return false; + } + reshape_cnode->set_inputs({reshape_origin_inputs[0], reshape_origin_inputs[1]}); + + slice_t->axes = new_axes; + slice_t->begin = new_begin; + slice_t->size = new_size; + auto status = SwapSliceWithPreceed(graph, slice_cnode, reshape_cnode, 1); + if (status != RET_OK) { + return false; + } + reshape_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(slice_cnode); + return true; +} + +CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &matmul_cnode, const std::vector &shape_in, + const int abnormal_axe_in, const int count_sliced_axe_in, + const bool slice_at_front) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(matmul_cnode != nullptr); + std::vector new_axes1(shape_in.size()); + std::iota(new_axes1.begin(), new_axes1.end(), 0); + std::vector new_begin1(shape_in.size(), 0); + std::vector new_size1(shape_in.size(), -1); + if (slice_at_front) { + new_begin1[abnormal_axe_in] = count_sliced_axe_in; + } else { + new_size1[abnormal_axe_in] = shape_in[abnormal_axe_in] - count_sliced_axe_in; + } + auto new_slice1 = CreateSliceValueNode(graph, new_axes1, new_begin1, new_size1); + if (new_slice1 == nullptr) { + MS_LOG(ERROR) << "CreateSliceValueNode failed"; + return nullptr; + } + auto new_slice1_cnode = graph->NewCNode({new_slice1, matmul_cnode}); + new_slice1_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(new_slice1_cnode); + return new_slice1_cnode; +} + +CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &new_reshape1_cnode, + const std::vector &new_shape1, + const int abnormal_axe_in, const int count_sliced_axe_in, + const int count_sliced2, const bool slice_at_front) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(matmul_cnode != nullptr); + std::vector new_axes2(abnormal_axe_in + 1); + std::iota(new_axes2.begin(), new_axes2.end(), 0); + std::vector new_begin2(abnormal_axe_in + 1, 0); + std::vector new_size2(abnormal_axe_in + 1, -1); + if (count_sliced2 > new_shape1[abnormal_axe_in]) { + MS_LOG(WARNING) << "calculation error"; + return nullptr; + } + if (slice_at_front) { + new_begin2[abnormal_axe_in] = new_shape1[abnormal_axe_in] - count_sliced2; + } else { + new_size2[abnormal_axe_in] = count_sliced2; + } + auto new_slice2 = CreateSliceValueNode(graph, new_axes2, new_begin2, new_size2); + if (new_slice2 == nullptr) { + MS_LOG(ERROR) << "CreateSliceValueNode failed"; + return nullptr; + } + auto new_slice2_cnode = graph->NewCNode({new_slice2, new_reshape1_cnode}); + new_slice2_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(new_slice2_cnode); + return new_slice2_cnode; +} + +bool SlicePreposePass::PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &reshape_cnode, const CNodePtr &matmul_cnode, + const std::vector &shape_in, const std::vector &shape_out, + const int abnormal_axe_in, const int abnormal_index_out) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(reshape_cnode != nullptr); + auto manager = graph->manager(); + auto slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "slice_t is nullptr"; + return false; + } + auto slice_axes = slice_t->axes; + auto slice_begin = slice_t->begin; + auto slice_size = slice_t->size; + auto abnormal_axe_out = slice_axes[abnormal_index_out]; + MS_ASSERT(abnormal_axe_out + 1 < shape_out.size()); + int inter_size_in = 1; + int inter_size_out = 1; + for (auto i = 0; i < abnormal_axe_in; ++i) { + inter_size_in *= shape_in[i]; + } + for (auto i = 0; i < abnormal_axe_out; ++i) { + inter_size_out *= shape_out[i]; + } + if (inter_size_in != inter_size_out) { + MS_LOG(DEBUG) << "not support prepose now"; + return false; + } + int outer_size_in = 1; + int outer_size_out = 1; + for (auto i = abnormal_axe_in + 1; i < static_cast(shape_in.size()); ++i) { + outer_size_in *= shape_in[i]; + } + for (auto i = abnormal_axe_out + 1; i < static_cast(shape_out.size()); ++i) { + outer_size_out *= shape_out[i]; + } + const int count_sliced_axe_front = slice_begin[abnormal_index_out]; + const int count_sliced_axe_rear = + slice_size[abnormal_index_out] == -1 ? 0 : (shape_out[abnormal_axe_out] - slice_size[abnormal_index_out]); + if (count_sliced_axe_front * count_sliced_axe_rear > 0) { + MS_LOG(DEBUG) << "not border slice at abnormal axe, prepose with reshape failed"; + return false; + } + bool slice_at_front = count_sliced_axe_front > 0; + const int count_sliced_out = (count_sliced_axe_front + count_sliced_axe_rear) * outer_size_out; + const int count_sliced_axe_in = count_sliced_out / outer_size_in; + if (count_sliced_axe_in <= 0 || count_sliced_axe_in > shape_in[abnormal_axe_in]) { + MS_LOG(DEBUG) << "amount of sliced out tensor is illegal"; + return false; + } + // new_slice1 + auto new_slice1_cnode = CreateSlice1ForReshapePrepose(graph, slice_cnode, matmul_cnode, shape_in, abnormal_axe_in, + count_sliced_axe_in, slice_at_front); + if (new_slice1_cnode == nullptr) { + return false; + } + // new_reshape1 + std::vector new_shape1(abnormal_axe_in + 1); + for (int i = 0; i < abnormal_axe_in; ++i) { + new_shape1[i] = shape_in[i]; + } + new_shape1[abnormal_axe_in] = outer_size_in * (shape_in[abnormal_axe_in] - count_sliced_axe_in); + auto new_reshape1_cnode = CreateReshapeCNode(graph, new_shape1, slice_cnode->abstract()->Clone(), new_slice1_cnode); + if (new_reshape1_cnode == nullptr) { + return false; + } + // new_slice2 + const int count_sliced_abnormal_axe = shape_out[abnormal_axe_out] - (count_sliced_axe_front + count_sliced_axe_rear); + const int count_sliced2 = count_sliced_abnormal_axe * outer_size_out; + auto new_slice2_cnode = + CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1, abnormal_axe_in, + count_sliced_axe_in, count_sliced2, slice_at_front); + if (new_slice2_cnode == nullptr) { + return false; + } + // new_reshape2 + std::vector new_shape2(shape_out.begin(), shape_out.end()); + new_shape2[abnormal_axe_out] = count_sliced_abnormal_axe; + auto new_reshape2_cnode = CreateReshapeCNode(graph, new_shape2, slice_cnode->abstract()->Clone(), new_slice2_cnode); + if (new_reshape2_cnode == nullptr) { + return false; + } + new_reshape2_cnode->set_abstract(slice_cnode->abstract()->Clone()); + auto node_users = manager->node_users()[slice_cnode]; + for (auto &node_user : node_users) { + manager->SetEdge(node_user.first, node_user.second, new_reshape2_cnode); + } + return true; +} + +bool SlicePreposePass::GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector *inputs, + std::vector> *shapes, + std::vector *is_default_params) { + MS_ASSERT(arithmetic_cnode != nullptr); + for (size_t i = 1; i < arithmetic_cnode->inputs().size(); ++i) { + auto input = arithmetic_cnode->input(i); + MS_ASSERT(input != nullptr); + std::vector shape; + if (utils::isa(input)) { + auto parameter = utils::cast(input); + if (!parameter->has_default()) { // if one input is input placeholder, we can't change it + return false; + } else { + shape = GetDefaultParamShape(parameter); + is_default_params->push_back(true); + } + } else { // input is CNode + if (!utils::isa(input)) { + MS_LOG(ERROR) << "one of Arithmetic's input is not CNode"; + return false; + } + shape = GetCNodeInputShape(arithmetic_cnode, i); + is_default_params->push_back(false); + } + inputs->push_back(input); + shapes->push_back(shape); + } + return true; +} + /* * Prepose condition: * the softmax axis is not sliced @@ -116,18 +754,12 @@ bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNode MS_ASSERT(graph != nullptr); MS_ASSERT(slice_cnode != nullptr); MS_ASSERT(softmax_cnode != nullptr); - auto softmax_primc = GetValueNode>(softmax_cnode->input(0)); - if (softmax_primc == nullptr) { - MS_LOG(ERROR) << "softmax_primc is nullptr"; - return false; - } - auto softmax_primt = softmax_primc->GetPrimitiveT(); - if (softmax_primt == nullptr || softmax_primt->value.AsSoftMax() == nullptr) { - MS_LOG(ERROR) << "softmax_primt is nullptr"; + auto softmax_t = GetSoftmaxT(softmax_cnode); + if (softmax_t == nullptr) { + MS_LOG(ERROR) << "softmax_t is nullptr"; return false; } - auto softmax_attr = softmax_primt->value.AsSoftMax(); - auto softmax_axis = softmax_attr->axis; + auto softmax_axis = softmax_t->axis; auto shape = GetCNodeInputShape(softmax_cnode, 1); if (softmax_axis == -1) { if (shape.empty()) { // when softmax axis == -1, shape info is needed to determine whether slice can be preposed @@ -137,7 +769,9 @@ bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNode } auto slice_t = GetSliceT(slice_cnode); - MS_ASSERT(slice_t != nullptr); + if (slice_t == nullptr) { + return false; + } auto slice_axes = slice_t->axes; auto slice_begin = slice_t->begin; auto slice_size = slice_t->size; @@ -158,7 +792,511 @@ bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNode } } auto status = SwapSliceWithPreceed(graph, slice_cnode, softmax_cnode, 1); - return status == RET_OK; + if (status != RET_OK) { + return false; + } + softmax_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(slice_cnode); + return true; +} + +/* + * Prepose condition: + * require shape info + * when reshape is normal(memory view is not changed, such as 4x5 reshaped to 4x1x5), can always prepose + * when reshape is abnormal(such as 4x5 reshaped to 5x4), can prepose under some constraint + * For abnormal mode: + * we only support border(not slice at center) slice at first mismatch axe, + * and we only support matmul->reshape->slice => matmul->slice->reshape*->slice*(drop "dead" data)->reshape now, + * cause the performance influence introduced by additional (reshape*->slice*) has not been fully evaluated. + */ +bool SlicePreposePass::PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &reshape_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(reshape_cnode != nullptr); + auto shape_in = GetCNodeInputShape(reshape_cnode, 1); + auto shape_out = GetCNodeInputShape(slice_cnode, 1); + auto shape_out_copy = shape_out; + if (shape_in.empty() || shape_out.empty()) { + MS_LOG(DEBUG) << "Reshape can't be preposed if either input or output shape is unknown"; + return false; + } + if (reshape_cnode->inputs().size() == 3 && utils::isa(reshape_cnode->input(2))) { + auto reshape_input_shape = utils::cast(reshape_cnode->input(2)); + if (!reshape_input_shape->has_default()) { + MS_LOG(ERROR) << "Reshape input shape is not constant"; + return false; + } + } + std::vector mapped_axe(shape_out.size(), -1); + int abnormal_axe_in = GetReshapeAbnormalAxeIn(shape_in, shape_out, &mapped_axe); + bool is_normal_mode = true; // if all sliced axe can be found in input shape, normal + bool support_abnormal_mode = true; // if first mismatch axe are sliced and no more other axes are sliced, abnormal + int abnormal_index_out = GetReshapeAbnormalIndexOut(slice_cnode, mapped_axe, shape_out, &shape_out_copy, + &is_normal_mode, &support_abnormal_mode); + if (is_normal_mode) { + return PreposeWithNormalReshape(graph, slice_cnode, reshape_cnode, shape_in, shape_out_copy, mapped_axe); + } else if (support_abnormal_mode) { + auto matmul_node = reshape_cnode->input(1); + MS_ASSERT(matmul_node != nullptr); + if (IsMultiOutputTensors(graph, matmul_node) || !utils::isa(matmul_node)) { + MS_LOG(DEBUG) << "not matmul->reshape->slice"; + return false; + } + auto matmul_cnode = matmul_node->cast(); + if (matmul_cnode == nullptr) { + MS_LOG(ERROR) << "matmul_cnode is nullptr"; + return false; + } + if (GetCNodeType(matmul_cnode) != schema::PrimitiveType_FullConnection && + GetCNodeType(matmul_cnode) != schema::PrimitiveType_MatMul) { + MS_LOG(DEBUG) << "not matmul->reshape->slice pattern"; + return false; + } + return PreposeWithAbnormalReshape(graph, slice_cnode, reshape_cnode, matmul_cnode, shape_in, shape_out, + abnormal_axe_in, abnormal_index_out); + } + return false; +} + +/* + * Prepose condition: + * require shape info + */ +bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &matmul_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(matmul_cnode != nullptr); + auto matmul_shape = GetCNodeInputShape(slice_cnode, 1); + const int dims = matmul_shape.size(); + if (dims == 0) { + // if Matmul's output shape is unknown, can't do prepose, cause we can't determine last two axes + return false; + } + auto slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "slice_t is nullptr"; + return RET_ERROR; + } + auto axes = slice_t->axes; + auto begin = slice_t->begin; + auto size = slice_t->size; + // matmul not support broadcast now, it makes things simpler + auto manager = graph->manager(); + std::shared_ptr tr = std::make_shared(manager.get()); + if (tr == nullptr) { + MS_LOG(ERROR) << "create FuncGraphTransaction failed"; + return false; + } + auto node_users = manager->node_users()[slice_cnode]; + bool changed = false; + + bool prepose_to_left = false; // if only the last axe is sliced, not need prepose to left + bool prepose_to_right = false; // if only the second last axe is sliced, not need prepose to right + for (size_t i = 0; i < axes.size(); ++i) { + if (begin[i] != 0 || (size[i] != -1 && size[i] != matmul_shape[axes[i]])) { + if (axes[i] != dims - 1) { + prepose_to_left = true; + } else if (axes[i] != dims - 2) { + prepose_to_right = true; + } + } + } + + if (prepose_to_left) { // left matrix + auto left_axes = axes; + auto left_begin = begin; + auto left_size = size; + for (size_t i = 0; i < left_axes.size(); ++i) { + if (left_axes[i] == dims - 1) { + left_begin[i] = 0; + left_size[i] = -1; + } + } + auto left_slice_vnode = CreateSliceValueNode(graph, left_axes, left_begin, left_size); + if (left_slice_vnode == nullptr) { + MS_LOG(ERROR) << "CreateSliceValueNode failed"; + return false; + } + auto new_slice_cnode = InsertSlice(graph, left_slice_vnode, matmul_cnode, 1, tr); + new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(new_slice_cnode); + changed = true; + } + if (prepose_to_right) { // right matrix + auto right_axes = axes; + auto right_begin = begin; + auto right_size = size; + for (size_t i = 0; i < right_axes.size(); ++i) { + if (right_axes[i] == dims - 2) { + right_begin[i] = 0; + right_size[i] = -1; + } + } + auto right_slice_vnode = CreateSliceValueNode(graph, right_axes, right_begin, right_size); + if (right_slice_vnode == nullptr) { + MS_LOG(ERROR) << "CreateSliceValueNode failed"; + return false; + } + auto new_slice_cnode = InsertSlice(graph, right_slice_vnode, matmul_cnode, 2, tr); + new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(new_slice_cnode); + changed = true; + } + if (changed) { + matmul_cnode->set_abstract(slice_cnode->abstract()->Clone()); + for (auto &node_user : node_users) { + tr->SetEdge(node_user.first, node_user.second, matmul_cnode); + } + tr->Commit(); + // we don't need graph->DropNode(slice_cnode); + } + return changed; +} + +/* + * Prepose condition: + * require shape info + * only support slice at first output axe now, and useAxis must be false + */ +bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &fc_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(fc_cnode != nullptr); + auto shape_in = GetCNodeInputShape(fc_cnode, 1); + auto shape_out = GetCNodeInputShape(slice_cnode, 1); + if (shape_in.empty() || shape_out.size() != 2) { + MS_LOG(DEBUG) << "FullConnection can't be preposed if input shape is unknown or output shape is illegal"; + return false; + } + auto fc_t = GetFcT(fc_cnode); + if (fc_t == nullptr || fc_t->useAxis) { + MS_LOG(DEBUG) << "prepose with fc only support useAxis == false currently"; + return false; + } + auto slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "slice_t is nullptr"; + return RET_ERROR; + } + auto axes = slice_t->axes; + auto begin = slice_t->begin; + auto size = slice_t->size; + for (size_t i = 0; i < axes.size(); ++i) { + if (axes[i] == 1) { + if (begin[i] != 0 || (size[i] != -1 && size[i] != shape_out[1])) { + MS_LOG(DEBUG) << "prepose with fc only support first output axe is sliced currently"; + return false; + } + } + } + + std::vector mapped_axe(shape_out.size(), -1); + int32_t inner_size_in = 1; + for (size_t i = 0; i < shape_in.size(); ++i) { + inner_size_in *= shape_in[i]; + int32_t inner_size_out = 1; + for (size_t j = 0; j < shape_out.size(); ++j) { + inner_size_out *= shape_out[j]; + if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) { + mapped_axe[j] = i; + break; + } + } + } + if (mapped_axe[0] == -1) { + MS_LOG(DEBUG) << "first axe in output can't find correspond input axe, can't do prepose"; + return false; + } + + std::vector new_axes(shape_in.size()); + std::iota(new_axes.begin(), new_axes.end(), 0); + std::vector new_begin(shape_in.size(), 0); + std::vector new_size(shape_in.size(), -1); + new_begin[mapped_axe[0]] = begin[0]; + new_size[mapped_axe[0]] = size[0]; + auto new_slice_vnode = CreateSliceValueNode(graph, new_axes, new_begin, new_size); + if (new_slice_vnode == nullptr) { + MS_LOG(ERROR) << "CreateSliceValueNode failed"; + return false; + } + + auto manager = graph->manager(); + std::shared_ptr tr = std::make_shared(manager.get()); + if (tr == nullptr) { + MS_LOG(ERROR) << "create FuncGraphTransaction failed"; + return false; + } + auto new_slice_cnode = InsertSlice(graph, new_slice_vnode, fc_cnode, 1, tr); + fc_cnode->set_abstract(slice_cnode->abstract()->Clone()); + new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(new_slice_cnode); + + auto node_users = manager->node_users()[slice_cnode]; + for (auto &node_user : node_users) { + tr->SetEdge(node_user.first, node_user.second, fc_cnode); + } + tr->Commit(); + return true; +} + +/* + * Prepose condition: + * not require shape info, can always prepose + */ +bool SlicePreposePass::PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &transpose_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(transpose_cnode != nullptr); + auto transpose_primc = GetValueNode>(transpose_cnode->input(0)); + if (transpose_primc == nullptr) { + MS_LOG(ERROR) << "transpose_primc is nullptr"; + return false; + } + auto transpose_primt = transpose_primc->GetPrimitiveT(); + if (transpose_primt == nullptr || transpose_primt->value.AsTranspose() == nullptr) { + MS_LOG(ERROR) << "transpose_primt is nullptr"; + return false; + } + auto transpose_attr = transpose_primt->value.AsTranspose(); + auto perm = transpose_attr->perm; + + auto slice_t = GetSliceT(slice_cnode); + if (slice_t == nullptr) { + MS_LOG(ERROR) << "GetSlicT failed"; + return false; + } + auto old_axes = slice_t->axes; + auto old_begin = slice_t->begin; + auto old_size = slice_t->size; + auto &slice_begin = slice_t->begin; + auto &slice_size = slice_t->size; + // perm is random shuffle of [0...n-1] according to ops/transpose.cc + for (size_t i = 0; i < perm.size(); ++i) { + if (perm[i] != static_cast(i)) { + for (size_t j = 0; j < old_axes.size(); ++j) { + if (old_axes[j] == static_cast(i)) { + slice_begin[perm[i]] = old_begin[j]; + slice_size[perm[i]] = old_size[j]; + break; + } + } + } + } + auto status = SwapSliceWithPreceed(graph, slice_cnode, transpose_cnode, 1); + if (status != RET_OK) { + return false; + } + transpose_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(slice_cnode); + return true; +} +/* + * Prepose condition: + * may or may not require shape info + */ +bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &arithmetic_cnode) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slice_cnode != nullptr); + MS_ASSERT(arithmetic_cnode != nullptr); + auto manager = graph->manager(); + auto node_users = manager->node_users()[slice_cnode]; + std::shared_ptr tr = std::make_shared(manager.get()); + if (tr == nullptr) { + MS_LOG(ERROR) << "create FuncGraphTransaction failed"; + return false; + } + bool changed = false; + std::vector inputs; + std::vector> shapes; + std::vector is_default_params; + if (!GetArithmeticInputInfo(arithmetic_cnode, &inputs, &shapes, &is_default_params)) { + return false; + } + + for (size_t i = 1; i < arithmetic_cnode->inputs().size(); ++i) { + auto &input = inputs[i - 1]; + if (IsScalarNode(input)) { // scalar not need prepose + continue; + } + auto &shape = shapes[i - 1]; + const size_t another_index = kArithmeticInputNum - i; + auto &another_input = inputs[another_index]; + auto &another_shape = shapes[another_index]; + if (IsScalarNode(input)) { + continue; + } else if (shape.empty()) { // infershape failed at this input + if (IsScalarNode(another_input)) { // if another input is scalar, we can process this one + auto new_slice_vnode = CopySliceValueNode(graph, slice_cnode); + if (new_slice_vnode == nullptr) { + changed = false; + break; + } + auto new_slice_cnode = InsertSlice(graph, new_slice_vnode, arithmetic_cnode, i, tr); + new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(new_slice_cnode); + changed = true; + break; + } else { // if another input's shape is not scalar, can't be processed + changed = false; + break; + } + } else { // shape not empty + if (!another_shape.empty() || IsScalarNode(another_input)) { + std::vector new_axes; + std::vector new_begin; + std::vector new_size; + auto status = SliceParamDeBroadcast(slice_cnode, shape, &new_axes, &new_begin, &new_size); + if (status == lite::RET_NO_CHANGE) { + continue; + } + if (status != lite::RET_OK) { + changed = false; + break; + } + auto new_slice_vnode = CreateSliceValueNode(graph, new_axes, new_begin, new_size); + if (new_slice_vnode == nullptr) { + changed = false; + break; + } + auto new_slice_cnode = InsertSlice(graph, new_slice_vnode, arithmetic_cnode, i, tr); + new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone()); + ClearCNodeAbstractValue(new_slice_cnode); + changed = true; + } else { + changed = false; + break; + } + } + } + if (changed) { + arithmetic_cnode->set_abstract(slice_cnode->abstract()->Clone()); + for (auto &node_user : node_users) { + tr->SetEdge(node_user.first, node_user.second, arithmetic_cnode); + } + tr->Commit(); + // we don't need graph->DropNode(slice_cnode); + } + return changed; +} // namespace mindspore::opt +/* + * Prepose condition: + * not require shape info + */ +bool SlicePreposePass::MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode, + const CNodePtr &slice2_cnode) { + if (slice2_cnode->inputs().size() != lite::kDoubleNum) { + MS_LOG(INFO) << "Slice read attrs from input is not supported now"; + return false; + } + auto slice1_t = GetSliceT(slice1_cnode); // bottom node + auto slice2_t = GetSliceT(slice2_cnode); // top node + if (slice1_t == nullptr || slice2_t == nullptr) { + MS_LOG(ERROR) << "slice_t is null"; + return false; + } + auto begin_slice1 = slice1_t->begin; + auto size_slice1 = slice1_t->size; + auto axes_slice1 = slice1_t->axes; + auto begin_slice2 = slice2_t->begin; + auto size_slice2 = slice2_t->size; + auto axes_slice2 = slice2_t->axes; + auto status1 = VerifySliceAttrs(slice1_cnode); + auto status2 = VerifySliceAttrs(slice2_cnode); + if (status1 != RET_OK || status2 != RET_OK) { + return false; + } + + auto manager = graph->manager(); + auto node_users = manager->node_users()[slice1_cnode]; + int axe_max1 = *std::max_element(axes_slice1.begin(), axes_slice1.end()); + int axe_max2 = *std::max_element(axes_slice2.begin(), axes_slice2.end()); + int axe_max = std::max(axe_max1, axe_max2); + auto &begin_new = slice2_t->begin; + auto &size_new = slice2_t->size; + auto &axes_new = slice2_t->axes; + axes_new.resize(axe_max + 1); + std::iota(axes_new.begin(), axes_new.end(), 0); + begin_new.assign(axe_max + 1, 0); + size_new.assign(axe_max + 1, -1); + for (int i = 0; i <= axe_max; ++i) { + for (size_t j = 0; j < axes_slice2.size(); ++j) { + if (axes_slice2[j] == i) { + begin_new[i] = begin_slice2[j]; + size_new[i] = size_slice2[j]; + break; + } + } + for (size_t j = 0; j < axes_slice1.size(); ++j) { + if (axes_slice1[j] == i) { + begin_new[i] = begin_new[i] + begin_slice1[j]; + if (size_new[i] == -1) { + size_new[i] = size_slice1[j]; + } else { + if (size_slice1[j] == -1) { + size_new[i] = std::max(size_new[i] - begin_slice1[i], 0); // clip with zero to avoid invalid negative value + } else { + size_new[i] = std::max(std::min(size_new[i] - begin_slice1[j], size_slice1[j]), 0); + } + } + break; + } + } + } + slice2_cnode->set_abstract(slice1_cnode->abstract()->Clone()); + for (auto &node_user : node_users) { + manager->SetEdge(node_user.first, node_user.second, slice2_cnode); + } + return true; +} + +/* + * Prepose condition: + * when all sibling slices do same work + * can be optimize to not require all siblings are slice + */ +bool SlicePreposePass::MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(slices->size() >= 2); + auto manager = graph->manager(); + auto first_slice = utils::cast(slices->at(0).first); + if (first_slice == nullptr || GetCNodeType(first_slice) != schema::PrimitiveType_Slice) { + MS_LOG(ERROR) << "first node is not Slice"; + return false; + } + auto first_parent = first_slice->input(1); + if (first_parent == nullptr) { + MS_LOG(ERROR) << "first slice node's parent is nullptr"; + return false; + } + std::shared_ptr tr = std::make_shared(manager.get()); + if (tr == nullptr) { + MS_LOG(ERROR) << "create FuncGraphTransaction failed"; + return false; + } + for (size_t i = 1; i < slices->size(); ++i) { + auto slice = utils::cast(slices->at(i).first); + if (slice == nullptr || GetCNodeType(slice) != schema::PrimitiveType_Slice) { + MS_LOG(ERROR) << "current node is not Slice"; + return false; + } + auto parent = slice->input(1); + if (parent == nullptr || parent != first_parent) { + MS_LOG(ERROR) << "not all slices have same parent node"; + return false; + } + auto node_users = manager->node_users()[slices->at(i).first]; + for (auto &node_user : node_users) { + tr->SetEdge(node_user.first, node_user.second, slices->at(0).first); + } + } + tr->Commit(); + return true; } bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, @@ -171,6 +1309,26 @@ bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slic case schema::PrimitiveType_SoftMax: { return PreposeWithSoftmax(graph, slice_cnode, preceed_cnode); } + case schema::PrimitiveType_Reshape: { + return PreposeWithReshape(graph, slice_cnode, preceed_cnode); + } + case schema::PrimitiveType_MatMul: { + return PreposeWithMatmul(graph, slice_cnode, preceed_cnode); + } + case schema::PrimitiveType_FullConnection: { + return PreposeWithFullConnection(graph, slice_cnode, preceed_cnode); + } + case schema::PrimitiveType_Transpose: { + return PreposeWithTranspose(graph, slice_cnode, preceed_cnode); + } + case schema::PrimitiveType_Sub: + case schema::PrimitiveType_Mul: + case schema::PrimitiveType_Add: { + return PreposeWithArithmetic(graph, slice_cnode, preceed_cnode); + } + case schema::PrimitiveType_Slice: { + return MergeSequentialSlice(graph, slice_cnode, preceed_cnode); + } default: { MS_LOG(DEBUG) << "Node type " << preceed_node_type << " currently not support SlicePrepose"; } @@ -216,6 +1374,12 @@ bool SlicePreposePass::Run(const FuncGraphPtr &graph) { } auto output_node_list = GetRealNodeUsedList(graph, utils::cast(preceed_node)); if (output_node_list->size() > 1) { // referenced by multi nodes + if (SiblingsAreSameSlice(graph, output_node_list)) { + if (MergeParallelSlice(graph, output_node_list)) { + this_time_changed = true; + break; + } + } continue; } else { if (utils::isa(preceed_node)) { diff --git a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h index 99ee5b2a9c..52fa09d79a 100644 --- a/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h +++ b/mindspore/lite/tools/optimizer/graph/slice_prepose_pass.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "tools/converter/converter_flags.h" #include "backend/optimizer/common/pass.h" #include "include/errorcode.h" @@ -40,11 +41,54 @@ class SlicePreposePass : public Pass { void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } private: - schema::SliceT *GetSliceT(const CNodePtr &cnode); - bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode); + void ClearCNodeAbstractValue(const CNodePtr &cnode); STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode, const int index, const TransactionPtr &tr = nullptr); + ValueNodePtr CreateSliceValueNode(const FuncGraphPtr &graph, const std::vector &axes, + const std::vector &begin, const std::vector &size); + ValueNodePtr CopySliceValueNode(const FuncGraphPtr &graph, const CNodePtr &slice_cnode); + CNodePtr InsertSlice(const FuncGraphPtr &graph, const ValueNodePtr &slice_vnode, const CNodePtr &preceed_cnode, + const int index, const TransactionPtr &tr); + STATUS VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim = -1); + STATUS SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector &ref_shape, + std::vector *axes, std::vector *begin, std::vector *size); + CNodePtr CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector &shape, + const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode); + bool SiblingsAreSameSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &output_node_list, + const std::vector &ref_shape = {}); + int GetReshapeAbnormalAxeIn(const std::vector &shape_in, const std::vector &shape_out, + std::vector *mapped_axe); + int GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector &mapped_axe, + const std::vector &shape_out, std::vector *shape_out_copy, + bool *is_normal_mode, bool *support_abnormal_mode); + bool PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode, + const std::vector &shape_in, const std::vector &shape_out_copy, + const std::vector &mapped_axe); + CNodePtr CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &matmul_cnode, const std::vector &shape_in, + const int abnormal_axe_in, const int count_sliced_axe_in, + const bool slice_at_front); + CNodePtr CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, + const CNodePtr &new_reshape1_cnode, const std::vector &new_shape1, + const int abnormal_axe_in, const int count_sliced_axe_in, + const int count_sliced2, const bool slice_at_front); + bool PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode, + const CNodePtr &matmul_cnode, const std::vector &shape_in, + const std::vector &shape_out, const int abnormal_axe_in, + const int abnormal_index_out); + bool GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector *inputs, + std::vector> *shapes, std::vector *is_default_params); + + bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode); + bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode); + bool PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode); + bool PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &matmul_cnode); + bool PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &fc_cnode); + bool PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &transpose_cnode); + bool PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &arithmetic_cnode); + bool MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode, const CNodePtr &slice2_cnode); + bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices); private: FmkType fmk_type = lite::converter::FmkType_ONNX;