|
|
|
@ -1790,38 +1790,92 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
|
|
|
|
auto func_graph = std::make_shared<FuncGraph>();
|
|
|
|
|
func_graph->debug_info()->set_name("top");
|
|
|
|
|
// Generate and copy a ValueNode, or a CNode with its child nodes
|
|
|
|
|
static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, const AnfNodePtr ¶m_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(param_node);
|
|
|
|
|
if (param_node->isa<ValueNode>()) {
|
|
|
|
|
return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Parameter default value is CNode.
|
|
|
|
|
std::size_t index = 0;
|
|
|
|
|
std::vector<AnfNodePtr> old_cnodes;
|
|
|
|
|
old_cnodes.emplace_back(param_node);
|
|
|
|
|
auto res = func_graph->NewCNode({});
|
|
|
|
|
std::vector<CNodePtr> new_cnodes;
|
|
|
|
|
new_cnodes.emplace_back(res);
|
|
|
|
|
while (index < old_cnodes.size()) {
|
|
|
|
|
auto current = old_cnodes[index];
|
|
|
|
|
auto current_new_cnode = new_cnodes[index];
|
|
|
|
|
index++;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(current);
|
|
|
|
|
if (current->isa<CNode>()) {
|
|
|
|
|
auto &inputs = current->cast<CNodePtr>()->inputs();
|
|
|
|
|
for (auto it = inputs.begin(); it != inputs.end(); it++) {
|
|
|
|
|
AnfNodePtr input = *it;
|
|
|
|
|
if (input != nullptr && input->isa<CNode>()) {
|
|
|
|
|
old_cnodes.emplace_back(input);
|
|
|
|
|
auto new_cnode = func_graph->NewCNode({});
|
|
|
|
|
new_cnodes.emplace_back(new_cnode);
|
|
|
|
|
current_new_cnode->add_input(new_cnode);
|
|
|
|
|
} else if (input->isa<ValueNode>()) {
|
|
|
|
|
current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// def top(*arg, *kwargs):
|
|
|
|
|
auto param_vargs = func_graph->add_parameter();
|
|
|
|
|
auto args_name = "args";
|
|
|
|
|
param_vargs->set_name(args_name);
|
|
|
|
|
param_vargs->debug_info()->set_name(args_name);
|
|
|
|
|
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
|
|
|
|
auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
|
|
|
|
|
if (current_graph == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto param_vkwargs = func_graph->add_parameter();
|
|
|
|
|
args_name = "kwargs";
|
|
|
|
|
param_vkwargs->set_name(args_name);
|
|
|
|
|
param_vkwargs->debug_info()->set_name(args_name);
|
|
|
|
|
auto func_graph = std::make_shared<FuncGraph>();
|
|
|
|
|
func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
|
|
|
|
|
|
|
|
|
|
func_graph->set_has_vararg(true);
|
|
|
|
|
func_graph->set_has_kwarg(true);
|
|
|
|
|
func_graph->set_kwonlyargs_count(0);
|
|
|
|
|
// Copy all parameters information
|
|
|
|
|
for (auto ¶ : current_graph->parameters()) {
|
|
|
|
|
auto param = func_graph->add_parameter();
|
|
|
|
|
auto orig_param = para->cast<ParameterPtr>();
|
|
|
|
|
auto name = orig_param->name();
|
|
|
|
|
param->set_name(name);
|
|
|
|
|
param->debug_info()->set_name(name);
|
|
|
|
|
}
|
|
|
|
|
func_graph->set_has_vararg(current_graph->has_vararg());
|
|
|
|
|
func_graph->set_has_kwarg(current_graph->has_kwarg());
|
|
|
|
|
func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
|
|
|
|
|
// Copy all default values
|
|
|
|
|
for (auto &d : current_graph->parameter_default_value()) {
|
|
|
|
|
func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// cell_obj
|
|
|
|
|
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
|
|
|
|
|
parse::UpdateFuncGraphFlags(cell, func_graph);
|
|
|
|
|
// top graph's construct flag
|
|
|
|
|
if (py::hasattr(cell, "construct")) {
|
|
|
|
|
parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ret = cell_obj(*arg, *kwargs)
|
|
|
|
|
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
|
|
|
|
|
|
|
|
|
|
// return ret
|
|
|
|
|
func_graph->set_output(call_fn);
|
|
|
|
|
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
|
|
|
|
|
auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
|
|
|
|
|
if (!unpacking) {
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.emplace_back(NewValueNode(cell_ptr));
|
|
|
|
|
auto ¶ms = func_graph->parameters();
|
|
|
|
|
(void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
|
|
|
|
|
[](AnfNodePtr node) -> AnfNodePtr { return node; });
|
|
|
|
|
func_graph->set_output(func_graph->NewCNode(inputs));
|
|
|
|
|
} else {
|
|
|
|
|
// ret = cell_obj(*arg, *kwargs)
|
|
|
|
|
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
|
|
|
|
|
// return ret
|
|
|
|
|
func_graph->set_output(call_fn);
|
|
|
|
|
}
|
|
|
|
|
return func_graph;
|
|
|
|
|
}
|
|
|
|
|
} // namespace parse
|
|
|
|
|