Set abstracts for tuple_getitems of splitv

pull/7164/head
yujianfeng 5 years ago
parent 7126e316bc
commit 43e7cd2e42

@ -94,6 +94,7 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int
}
void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
const std::vector<AnfNodePtr> &base_splitv_outputs,
const std::vector<int> &size_splits_base, int split_dim, int num_split) {
SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
@ -106,6 +107,7 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
for (int i = 0; i < num_split; ++i) {
output_shape[split_dim] = size_splits_base[i];
base_output_shapes_base.emplace_back(output_shape);
AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
}
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
}
@ -127,11 +129,14 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
// Start to divide the outputs of Split.
std::vector<int> size_splits_base;
std::vector<AnfNodePtr> base_splitv_outputs;
const auto base_split_size = divisor * small_split_size;
int nodes_num = 0;
int cur_output_index = 0;
while (num_split - cur_output_index > divisor) {
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
@ -142,7 +147,9 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
if (cur_output_index < num_split) {
auto last_node_num_split = num_split - cur_output_index;
if (last_node_num_split > 1) {
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
std::vector<int> size_splits_new_last(last_node_num_split, small_split_size);
SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
// Create new output shape and new output type id for the last Splitv node
@ -154,13 +161,15 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
size_splits_base.emplace_back(last_node_num_split * small_split_size);
} else {
make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num));
auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
make_tuple_inputs.emplace_back(tuple_getitem);
size_splits_base.emplace_back(small_split_size);
}
nodes_num++;
}
// Set Attr and abstract for the base splitv
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num);
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num);
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple;
}

Loading…
Cancel
Save