diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 8f4bb554b8..249678e30c 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -106,6 +106,7 @@ const char NAMED_PRIMITIVE_COMPARE[] = "Compare"; const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant"; const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators"; const char NAMED_PRIMITIVE_SLICE[] = "slice"; +const char NAMED_PRIMITIVE_NAME[] = "Name"; const char NAMED_PRIMITIVE_NUM[] = "Num"; const char NAMED_PRIMITIVE_STR[] = "Str"; const char NAMED_PRIMITIVE_ITER[] = "iter"; diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index 0213aca229..8203d46ee2 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -69,6 +69,8 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args); const std::set ignore_infer_prim = {"make_ref", "mixed_precision_cast"}; const std::set force_infer_prim = {"TopK", "DropoutGenMask"}; +const std::set ignore_judge_dynamic_cell = {"Cell mindspore.nn.layer.basic.Dense", + "Cell mindspore.nn.probability.distribution.normal.Normal"}; const std::set unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, parse::NAMED_PRIMITIVE_NAMECONSTANT, parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 8f5c0335d1..6e92cc9842 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1352,6 +1352,16 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & return cell_id; } +std::string PynativeExecutor::GetCellInfo(const py::object &cell) { + if (py::isinstance(cell)) { + auto c_cell = py::cast(cell); + MS_EXCEPTION_IF_NULL(c_cell); + auto cell_info = c_cell->ToString(); + return cell_info; + } + return ""; +} + std::string PynativeExecutor::ParseNodeName(const std::shared_ptr &ast, const py::object &node, parse::AstMainType type) { MS_EXCEPTION_IF_NULL(ast); @@ -1372,6 +1382,16 @@ std::string PynativeExecutor::ParseNodeName(const std::shared_ptr &ast, const py::object &fn_node) { + MS_EXCEPTION_IF_NULL(ast); + py::list args = ast->GetArgs(fn_node); + for (size_t i = 1; i < args.size(); i++) { + std::string arg_name = py::cast(args[i].attr("arg")); + MS_LOG(DEBUG) << "Input arg name: " << arg_name; + cell_input_args_.emplace(arg_name); + } +} + bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse if/while expr"; py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); @@ -1383,14 +1403,25 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr(test_node.attr("id")); + if (cell_input_args_.find(id) != cell_input_args_.end()) { + return true; + } + } return false; } @@ -1405,7 +1436,9 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node) { MS_LOG(DEBUG) << "Parse for expr"; py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); + if (py::isinstance(body_node)) { + MS_LOG(DEBUG) << "Parse body of for expression is none!"; + return false; + } py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); size_t count = LongToSize(pcount); - MS_LOG(DEBUG) << "The for nodes count is " << count; + MS_LOG(DEBUG) << "The for nodes count in body is " << count; for (size_t i = 0; i < count; ++i) { auto it = py::cast(body_node)[i]; const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); @@ -1427,28 +1464,16 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr & return false; } -bool PynativeExecutor::IsDynamicCell(const py::object &cell) { - std::string cell_info; - if (py::isinstance(cell)) { - auto c_cell = py::cast(cell); - MS_EXCEPTION_IF_NULL(c_cell); - cell_info = c_cell->ToString(); - } - if (cell_info.find("nn.layer.basic.Dense") != string::npos) { - return false; - } - // using ast parse to check whether the construct of cell will be changed - auto ast = std::make_shared(cell); - bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); - if (!success) { - MS_LOG(ERROR) << "Parse code to ast tree failed"; +bool PynativeExecutor::ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node) { + MS_EXCEPTION_IF_NULL(ast); + py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY); + if (py::isinstance(func_obj)) { + MS_LOG(DEBUG) << "Parse body of cell is none!"; return false; } - py::object nodes = ast->GetAstNode(); - py::object func_obj = parse::python_adapter::GetPyObjAttr(nodes, parse::NAMED_PRIMITIVE_BODY); py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); size_t count = IntToSize(pcount); - MS_LOG(DEBUG) << "The nodes count is " << count; + MS_LOG(DEBUG) << "The nodes count in body is " << count; bool ret = false; for (size_t i = 0; i < count; ++i) { auto node = py::cast(func_obj)[i]; @@ -1461,13 +1486,35 @@ bool PynativeExecutor::IsDynamicCell(const py::object &cell) { ret = ParseIfWhileExprNode(ast, node); } if (ret) { - MS_LOG(INFO) << "Cur cell is dynamic"; + MS_LOG(INFO) << "Current cell is dynamic!"; break; } } return ret; } +bool PynativeExecutor::IsDynamicCell(const py::object &cell) { + std::string cell_info = GetCellInfo(cell); + if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) { + return false; + } + // using ast parse to check whether the construct of cell will be changed + auto ast = std::make_shared(cell); + bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); + if (!success) { + MS_LOG(ERROR) << "Parse code to ast tree failed"; + return false; + } + py::object fn_node = ast->GetAstNode(); + // get the name of input args as the initialize of dynamic_variables + ParseInputArgs(ast, fn_node); + // parse body context + bool ret = false; + ret = ParseBodyContext(ast, fn_node); + cell_input_args_.clear(); + return ret; +} + void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { auto cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index a9886b55b9..b37902888b 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -103,6 +103,9 @@ class PynativeExecutor : public std::enable_shared_from_this { // check cell struct bool IsDynamicCell(const py::object &cell); + std::string GetCellInfo(const py::object &cell); + void ParseInputArgs(const std::shared_ptr &ast, const py::object &fn_node); + bool ParseBodyContext(const std::shared_ptr &ast, const py::object &fn_node); bool ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node); bool ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node); bool ParseForExprNode(const std::shared_ptr &ast, const py::object &node); @@ -186,6 +189,7 @@ class PynativeExecutor : public std::enable_shared_from_this { // record all info of a graph std::unordered_map graph_info_map_; + std::unordered_set cell_input_args_; std::unordered_map cell_dynamic_map_; std::unordered_map cell_resource_map_; std::unordered_map> cell_graph_map_;