Fix the bug of setting shape when the axis is negative in the split fission pass

pull/2928/head
yujianfeng 5 years ago
parent 1f4944fa15
commit b5845b6b7b

@ -82,6 +82,9 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int
MS_EXCEPTION_IF_NULL(new_type_ids);
MS_EXCEPTION_IF_NULL(new_output_shapes);
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
if (split_dim < 0) {
split_dim += output_shape.size();
}
output_shape[split_dim] = split_size;
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
for (int i = 0; i < num_split; ++i) {
@ -97,6 +100,9 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
std::vector<std::vector<size_t>> base_output_shapes_base;
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
if (split_dim < 0) {
split_dim += output_shape.size();
}
for (int i = 0; i < num_split; ++i) {
output_shape[split_dim] = size_splits_base[i];
base_output_shapes_base.emplace_back(output_shape);

Loading…
Cancel
Save