!14656 insert u monad parameter before io monad parameter in auto_monad

From: @Margaret_wangrui
Reviewed-by: @zh_qh,@hwhewei
Signed-off-by: @zh_qh
pull/14656/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 4d37f35814

@ -41,8 +41,11 @@ using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
// Add or get a monad parameter.
AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
const abstract::AbstractBasePtr &abs) {
size_t params_size = func_graph->parameters().size();
size_t io_monad_location = params_size;
// Search for existed parameters, return it if found.
for (auto &node : func_graph->parameters()) {
for (size_t i = 0; i < params_size; i++) {
auto &node = func_graph->parameters()[i];
auto para = dyn_cast<Parameter>(node);
if (para == nullptr) {
continue;
@ -51,13 +54,23 @@ AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &
if (para_abs && *para_abs == *abs) {
return para;
}
if (HasAbstractIOMonad(para)) {
io_monad_location = i;
}
}
// Create a new parameter if not existed.
auto para = std::make_shared<Parameter>(func_graph);
para->set_name(name);
para->debug_info()->set_name(name);
para->set_abstract(abs);
func_graph->add_parameter(para);
// If io monad parameter added before u monad parameter, should insert u monad before io monad in parameters
if (io_monad_location != params_size && abs->isa<abstract::AbstractUMonad>()) {
std::vector<AnfNodePtr> params = func_graph->parameters();
params.insert(params.begin() + io_monad_location, para);
func_graph->set_parameters(params);
} else {
func_graph->add_parameter(para);
}
return para;
}

Loading…
Cancel
Save