|
|
|
@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_
|
|
|
|
|
new_addn->set_scope(origin_addn_cnode->scope());
|
|
|
|
|
new_addn->set_abstract(origin_addn_cnode->abstract());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn);
|
|
|
|
|
std::vector<int> dyn_input_sizes{SizeToInt(offset)};
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn);
|
|
|
|
|
return new_addn;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN
|
|
|
|
|
}
|
|
|
|
|
CNodePtr new_cnode = cnode;
|
|
|
|
|
while (origin_input_size > inputs_divisor_) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode);
|
|
|
|
|
std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
|
|
|
|
|
size_t cur_input_index = 1;
|
|
|
|
|
// Divide the inputs of addn by 63.
|
|
|
|
|
while (origin_input_size - cur_input_index + 1 > inputs_divisor_) {
|
|
|
|
|
// Divide the inputs of addn by inputs_divisor_.
|
|
|
|
|
while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
|
|
|
|
|
base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_));
|
|
|
|
|
cur_input_index += inputs_divisor_;
|
|
|
|
|
}
|
|
|
|
|
base_addn_inputs.push_back(
|
|
|
|
|
CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1));
|
|
|
|
|
|
|
|
|
|
for (size_t i = cur_input_index; i <= origin_input_size; i++) {
|
|
|
|
|
base_addn_inputs.push_back(new_cnode->input(i));
|
|
|
|
|
}
|
|
|
|
|
CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(base_addn);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_cnode);
|
|
|
|
|
base_addn->set_scope(new_cnode->scope());
|
|
|
|
|
base_addn->set_abstract(new_cnode->abstract());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn);
|
|
|
|
|
std::vector<int> dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)};
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn);
|
|
|
|
|
new_cnode = base_addn;
|
|
|
|
|
origin_input_size = base_addn->inputs().size() - 1;
|
|
|
|
|
}
|
|
|
|
|