|
|
@ -94,6 +94,7 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
|
|
|
|
void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
|
|
|
|
|
|
|
|
const std::vector<AnfNodePtr> &base_splitv_outputs,
|
|
|
|
const std::vector<int> &size_splits_base, int split_dim, int num_split) {
|
|
|
|
const std::vector<int> &size_splits_base, int split_dim, int num_split) {
|
|
|
|
SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
|
|
|
|
SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
|
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
|
|
|
@ -106,6 +107,7 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
|
|
|
|
for (int i = 0; i < num_split; ++i) {
|
|
|
|
for (int i = 0; i < num_split; ++i) {
|
|
|
|
output_shape[split_dim] = size_splits_base[i];
|
|
|
|
output_shape[split_dim] = size_splits_base[i];
|
|
|
|
base_output_shapes_base.emplace_back(output_shape);
|
|
|
|
base_output_shapes_base.emplace_back(output_shape);
|
|
|
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -127,11 +129,14 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
// Start to divide the outputs of Split.
|
|
|
|
// Start to divide the outputs of Split.
|
|
|
|
std::vector<int> size_splits_base;
|
|
|
|
std::vector<int> size_splits_base;
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> base_splitv_outputs;
|
|
|
|
const auto base_split_size = divisor * small_split_size;
|
|
|
|
const auto base_split_size = divisor * small_split_size;
|
|
|
|
int nodes_num = 0;
|
|
|
|
int nodes_num = 0;
|
|
|
|
int cur_output_index = 0;
|
|
|
|
int cur_output_index = 0;
|
|
|
|
while (num_split - cur_output_index > divisor) {
|
|
|
|
while (num_split - cur_output_index > divisor) {
|
|
|
|
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
|
|
|
|
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
|
|
|
|
|
|
|
|
base_splitv_outputs.push_back(tuple_getitem);
|
|
|
|
|
|
|
|
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
|
|
|
|
SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
|
|
|
|
SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
|
|
|
|
AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
|
|
|
|
AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
|
|
|
@ -142,7 +147,9 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
|
|
|
|
if (cur_output_index < num_split) {
|
|
|
|
if (cur_output_index < num_split) {
|
|
|
|
auto last_node_num_split = num_split - cur_output_index;
|
|
|
|
auto last_node_num_split = num_split - cur_output_index;
|
|
|
|
if (last_node_num_split > 1) {
|
|
|
|
if (last_node_num_split > 1) {
|
|
|
|
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
|
|
|
|
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
|
|
|
|
|
|
|
|
base_splitv_outputs.push_back(tuple_getitem);
|
|
|
|
|
|
|
|
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
|
|
|
|
std::vector<int> size_splits_new_last(last_node_num_split, small_split_size);
|
|
|
|
std::vector<int> size_splits_new_last(last_node_num_split, small_split_size);
|
|
|
|
SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
|
|
|
|
SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
|
|
|
|
// Create new output shape and new output type id for the last Splitv node
|
|
|
|
// Create new output shape and new output type id for the last Splitv node
|
|
|
@ -154,13 +161,15 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
|
|
|
|
AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
|
|
|
|
AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
|
|
|
|
size_splits_base.emplace_back(last_node_num_split * small_split_size);
|
|
|
|
size_splits_base.emplace_back(last_node_num_split * small_split_size);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num));
|
|
|
|
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
|
|
|
|
|
|
|
|
base_splitv_outputs.push_back(tuple_getitem);
|
|
|
|
|
|
|
|
make_tuple_inputs.emplace_back(tuple_getitem);
|
|
|
|
size_splits_base.emplace_back(small_split_size);
|
|
|
|
size_splits_base.emplace_back(small_split_size);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
nodes_num++;
|
|
|
|
nodes_num++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Set Attr and abstract for the base splitv
|
|
|
|
// Set Attr and abstract for the base splitv
|
|
|
|
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num);
|
|
|
|
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num);
|
|
|
|
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
|
|
|
|
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
|
|
|
|
return make_tuple;
|
|
|
|
return make_tuple;
|
|
|
|
}
|
|
|
|
}
|
|
|
|