From b5845b6b7b2e363a040a07e7b09e10110330badd Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Wed, 8 Jul 2020 09:31:38 +0800 Subject: [PATCH] Fix the bug of setting shape when the axis is negative in the split fission pass --- .../ccsrc/pre_activate/ascend/ir_fission/split_fission.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc index c39a5e01e6..2ab1cb6130 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc @@ -82,6 +82,9 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int MS_EXCEPTION_IF_NULL(new_type_ids); MS_EXCEPTION_IF_NULL(new_output_shapes); auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); + if (split_dim < 0) { + split_dim += output_shape.size(); + } output_shape[split_dim] = split_size; TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); for (int i = 0; i < num_split; ++i) { @@ -97,6 +100,9 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt std::vector> base_output_shapes_base; auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); + if (split_dim < 0) { + split_dim += output_shape.size(); + } for (int i = 0; i < num_split; ++i) { output_shape[split_dim] = size_splits_base[i]; base_output_shapes_base.emplace_back(output_shape);