Set abstracts for tuple_getitems of splitv

pull/7164/head
yujianfeng 5 years ago
parent 7126e316bc
commit 43e7cd2e42

@ -94,6 +94,7 @@ void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int
} }
void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
const std::vector<AnfNodePtr> &base_splitv_outputs,
const std::vector<int> &size_splits_base, int split_dim, int num_split) { const std::vector<int> &size_splits_base, int split_dim, int num_split) {
SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
@ -106,6 +107,7 @@ void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePt
for (int i = 0; i < num_split; ++i) { for (int i = 0; i < num_split; ++i) {
output_shape[split_dim] = size_splits_base[i]; output_shape[split_dim] = size_splits_base[i];
base_output_shapes_base.emplace_back(output_shape); base_output_shapes_base.emplace_back(output_shape);
AnfAlgo::SetOutputInferTypeAndShape({type_id}, {output_shape}, base_splitv_outputs[i].get());
} }
AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
} }
@ -127,11 +129,14 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
// Start to divide the outputs of Split. // Start to divide the outputs of Split.
std::vector<int> size_splits_base; std::vector<int> size_splits_base;
std::vector<AnfNodePtr> base_splitv_outputs;
const auto base_split_size = divisor * small_split_size; const auto base_split_size = divisor * small_split_size;
int nodes_num = 0; int nodes_num = 0;
int cur_output_index = 0; int cur_output_index = 0;
while (num_split - cur_output_index > divisor) { while (num_split - cur_output_index > divisor) {
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor); SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get()); AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs); AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
@ -142,7 +147,9 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
if (cur_output_index < num_split) { if (cur_output_index < num_split) {
auto last_node_num_split = num_split - cur_output_index; auto last_node_num_split = num_split - cur_output_index;
if (last_node_num_split > 1) { if (last_node_num_split > 1) {
CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
CNodePtr new_splitv = CreateSplitVNode(func_graph, tuple_getitem);
std::vector<int> size_splits_new_last(last_node_num_split, small_split_size); std::vector<int> size_splits_new_last(last_node_num_split, small_split_size);
SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
// Create new output shape and new output type id for the last Splitv node // Create new output shape and new output type id for the last Splitv node
@ -154,13 +161,15 @@ AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs); AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
size_splits_base.emplace_back(last_node_num_split * small_split_size); size_splits_base.emplace_back(last_node_num_split * small_split_size);
} else { } else {
make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num)); auto tuple_getitem = CreateTupleGetItem(func_graph, base_splitv, nodes_num);
base_splitv_outputs.push_back(tuple_getitem);
make_tuple_inputs.emplace_back(tuple_getitem);
size_splits_base.emplace_back(small_split_size); size_splits_base.emplace_back(small_split_size);
} }
nodes_num++; nodes_num++;
} }
// Set Attr and abstract for the base splitv // Set Attr and abstract for the base splitv
SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num); SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, base_splitv_outputs, size_splits_base, split_dim, nodes_num);
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple; return make_tuple;
} }

Loading…
Cancel
Save