|
|
@ -82,6 +82,9 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int
|
|
|
|
MS_EXCEPTION_IF_NULL(new_type_ids);
|
|
|
|
MS_EXCEPTION_IF_NULL(new_type_ids);
|
|
|
|
MS_EXCEPTION_IF_NULL(new_output_shapes);
|
|
|
|
MS_EXCEPTION_IF_NULL(new_output_shapes);
|
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
|
|
|
|
|
|
|
if (split_dim < 0) {
|
|
|
|
|
|
|
|
split_dim += output_shape.size();
|
|
|
|
|
|
|
|
}
|
|
|
|
output_shape[split_dim] = split_size;
|
|
|
|
output_shape[split_dim] = split_size;
|
|
|
|
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
|
|
|
|
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
|
|
|
|
for (int i = 0; i < num_split; ++i) {
|
|
|
|
for (int i = 0; i < num_split; ++i) {
|
|
|
@ -97,6 +100,9 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
|
|
|
|
std::vector<std::vector<size_t>> base_output_shapes_base;
|
|
|
|
std::vector<std::vector<size_t>> base_output_shapes_base;
|
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
|
|
|
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
|
|
|
|
TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
|
|
|
|
|
|
|
|
if (split_dim < 0) {
|
|
|
|
|
|
|
|
split_dim += output_shape.size();
|
|
|
|
|
|
|
|
}
|
|
|
|
for (int i = 0; i < num_split; ++i) {
|
|
|
|
for (int i = 0; i < num_split; ++i) {
|
|
|
|
output_shape[split_dim] = size_splits_base[i];
|
|
|
|
output_shape[split_dim] = size_splits_base[i];
|
|
|
|
base_output_shapes_base.emplace_back(output_shape);
|
|
|
|
base_output_shapes_base.emplace_back(output_shape);
|
|
|
|