diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 167165754d..7c52f8ccf1 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -90,6 +90,24 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; // define the common name const char NAMED_PRIMITIVE_LEN[] = "len"; +const char NAMED_PRIMITIVE_BODY[] = "body"; +const char NAMED_PRIMITIVE_ASSIGN[] = "Assign"; +const char NAMED_PRIMITIVE_FOR[] = "For"; +const char NAMED_PRIMITIVE_IF[] = "If"; +const char NAMED_PRIMITIVE_WHILE[] = "While"; +const char NAMED_PRIMITIVE_VALUE[] = "value"; +const char NAMED_PRIMITIVE_FUNC[] = "func"; +const char NAMED_PRIMITIVE_TEST[] = "test"; +const char NAMED_PRIMITIVE_LEFT[] = "left"; +const char NAMED_PRIMITIVE_CALL[] = "Call"; +const char NAMED_PRIMITIVE_SUBSCRIPT[] = "Subscript"; +const char NAMED_PRIMITIVE_ATTRIBUTE[] = "Attribute"; +const char NAMED_PRIMITIVE_COMPARE[] = "Compare"; +const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant"; +const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators"; +const char NAMED_PRIMITIVE_SLICE[] = "slice"; +const char NAMED_PRIMITIVE_NUM[] = "Num"; +const char NAMED_PRIMITIVE_STR[] = "Str"; const char NAMED_PRIMITIVE_ITER[] = "iter"; const char NAMED_PRIMITIVE_NEXT[] = "next"; const char NAMED_PRIMITIVE_GETITEM[] = "getitem"; @@ -104,6 +122,7 @@ const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call"; // define NAMED_PRIMITIVE_GETATTR "getattr" // define python inline attr +const char PYTHON_GET_METHOD_LEN[] = "__len__"; const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__"; const char PYTHON_GET_OBJ_DESC[] = "__str__"; diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index 0516fc3002..099588f974 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -28,6 +28,7 @@ #include "pybind11/pybind11.h" #include "ir/anf.h" #include "pybind_api/ir/primitive_py.h" +#include "pipeline/jit/parse/parse.h" #include "abstract/abstract_value.h" namespace mindspore { @@ -65,6 +66,9 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args); const std::set ignore_infer_prim = {"make_ref", "mixed_precision_cast"}; const std::set force_infer_prim = {"TopK", "DropoutGenMask"}; +const std::set unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, + parse::NAMED_PRIMITIVE_NAMECONSTANT, + parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 8434cf5e31..ecb14dfbee 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -28,6 +28,7 @@ #include "pybind_api/ir/tensor_py.h" #include "ir/param_info.h" #include "ir/anf.h" +#include "ir/cell.h" #include "ir/tensor.h" #include "utils/any.h" #include "utils/utils.h" @@ -36,10 +37,8 @@ #include "utils/config_manager.h" #include "utils/convert_utils_py.h" #include "frontend/operator/ops.h" -#include "frontend/operator/composite/composite.h" #include "frontend/operator/composite/do_signature.h" #include "pipeline/jit/parse/data_converter.h" -#include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/static_analysis/prim.h" #include "backend/session/session_factory.h" @@ -69,8 +68,7 @@ const size_t PTR_LEN = 15; const std::set vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient", "mixed_precision_cast"}; -namespace mindspore { -namespace pynative { +namespace mindspore::pynative { static std::shared_ptr session = nullptr; PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; std::mutex PynativeExecutor::instance_lock_; @@ -947,7 +945,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { if (op_mask) { MS_LOG(DEBUG) << "Cell paramsters(weights)"; // get the parameter name from parameter object - auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); + auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name"); if (py::isinstance(name_attr)) { MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } @@ -1221,7 +1219,7 @@ py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_e MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) { MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; - mindspore::parse::python_adapter::set_python_env_flag(true); + parse::python_adapter::set_python_env_flag(true); MsBackendPolicy backend_policy; #if (!defined ENABLE_GE) auto ms_context = MsContext::GetInstance(); @@ -1350,18 +1348,136 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & return cell_id; } +std::string PynativeExecutor::ParseNodeName(const std::shared_ptr &ast, const py::object &node, + parse::AstMainType type) { + MS_EXCEPTION_IF_NULL(ast); + if (py::isinstance(node)) { + MS_LOG(DEBUG) << "Get none type node!"; + return ""; + } + auto node_type = ast->GetNodeType(node); + MS_EXCEPTION_IF_NULL(node_type); + // check node type + parse::AstMainType node_main_type = node_type->main_type(); + if (node_main_type != type) { + MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type; + return ""; + } + std::string node_name = node_type->node_name(); + MS_LOG(DEBUG) << "Ast node is " << node_name; + return node_name; +} + +bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node) { + MS_LOG(DEBUG) << "Parse if/while expr"; + py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); + const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR); + if (node_name == parse::NAMED_PRIMITIVE_COMPARE) { + py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT); + py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS); + if (comparators_node.empty()) { + MS_LOG(DEBUG) << "Get comparators node falied!"; + return false; + } + const auto &left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR); + const auto &right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR); + MS_LOG(DEBUG) << "left is " << left << " right is " << right; + if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() || + unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) { + return true; + } + } + return false; +} + +bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node) { + MS_LOG(DEBUG) << "Parse assign expr"; + py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE); + const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR); + if (node_name == parse::NAMED_PRIMITIVE_CALL) { + py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC); + const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR); + if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) { + py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE); + py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE); + const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR); + return unchanged_named_primitive.find(node_name_in_slice_node) == unchanged_named_primitive.end(); + } + } + return false; +} + +bool PynativeExecutor::ParseForExprNode(const std::shared_ptr &ast, const py::object &node) { + MS_LOG(DEBUG) << "Parse for expr"; + py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); + py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); + size_t count = LongToSize(pcount); + MS_LOG(DEBUG) << "The for nodes count is " << count; + for (size_t i = 0; i < count; ++i) { + auto it = py::cast(body_node)[i]; + const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); + if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) { + return true; + } + } + return false; +} + +bool PynativeExecutor::IsDynamicCell(const py::object &cell) { + std::string cell_info; + if (py::isinstance(cell)) { + auto c_cell = py::cast(cell); + MS_EXCEPTION_IF_NULL(c_cell); + cell_info = c_cell->ToString(); + } + if (cell_info.find("nn.layer.basic.Dense") != string::npos) { + return false; + } + // using ast parse to check whether the construct of cell will be changed + auto ast = std::make_shared(cell); + bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); + if (!success) { + MS_LOG(ERROR) << "Parse code to ast tree failed"; + return false; + } + py::object nodes = ast->GetAstNode(); + py::object func_obj = parse::python_adapter::GetPyObjAttr(nodes, parse::NAMED_PRIMITIVE_BODY); + py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); + size_t count = IntToSize(pcount); + MS_LOG(DEBUG) << "The nodes count is " << count; + bool ret = false; + for (size_t i = 0; i < count; ++i) { + auto node = py::cast(func_obj)[i]; + const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT); + if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) { + ret = ParseAssignExprNode(ast, node); + } else if (node_name == parse::NAMED_PRIMITIVE_FOR) { + ret = ParseForExprNode(ast, node); + } else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) { + ret = ParseIfWhileExprNode(ast, node); + } + if (ret) { + MS_LOG(INFO) << "Cur cell is dynamic"; + break; + } + } + return ret; +} + void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { auto cell_id = GetCellId(cell, args); - MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; - if (!dynamic_shape && graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end()) { + MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id; + // check whether cell needed to construct grad graph + if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) { auto it = cell_resource_map_.find(cell_id); if (it != cell_resource_map_.end()) { resource_ = it->second; + MS_EXCEPTION_IF_NULL(resource_); } - MS_LOG(DEBUG) << "Newgraph already compiled"; + MS_LOG(DEBUG) << "Graph already compiled"; return; } - + // init resource for constructing forward graph and grad graph auto g = std::make_shared(); if (graph_context_.empty()) { MakeNewTopGraph(cell_id, args, g); @@ -1380,6 +1496,11 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg set_node_map(curr_g_, param, new_param, true); set_node_map(curr_g_, param_obj, new_param); } + // check whether the constrcut of cell will be changed + if (!dynamic_cell_) { + dynamic_cell_ = IsDynamicCell(cell); + MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_; + } } void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) { @@ -1391,23 +1512,14 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar } } } - if (dynamic_shape) { - auto it = df_builder_map_.find(cell_id); - if (it != df_builder_map_.end()) { - df_builder_map_.erase(cell_id); - } - auto ic = cell_resource_map_.find(cell_id); - if (ic != cell_resource_map_.end()) { - cell_resource_map_.erase(cell_id); - } - } + top_g_ = curr_g_ = g; + dynamic_cell_ = false; // a df builder is built for every top function graph df_builder_ = std::make_shared(); - df_builder_map_.emplace(cell_id, std::make_pair(df_builder_, nullptr)); - top_g_ = curr_g_ = g; + df_builder_map_[cell_id] = std::make_pair(df_builder_, nullptr); resource_ = std::make_shared(); resource_->results()[pipeline::kPynativeGraphId] = graph_id_++; - cell_resource_map_.emplace(cell_id, resource_); + cell_resource_map_[cell_id] = resource_; MS_LOG(DEBUG) << "New top graph for " << cell_id; first_grad_step_ = true; top_graph_cells_.emplace(cell_id); @@ -1452,12 +1564,11 @@ void PynativeExecutor::set_tuple_node_map(const FuncGraphPtr &g, const py::objec void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { auto cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; - if (!dynamic_shape && graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end()) { + if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) { MS_LOG(DEBUG) << "Endgraph already compiled"; return; } - - cell_graph_map_.emplace(cell_id, std::make_pair(curr_g_, false)); + cell_graph_map_[cell_id] = std::make_pair(curr_g_, false); auto out_id = GetId(out); // x =op1, y =op2, return (x, y) if (graph_info_map_[curr_g_].node_map.find(out_id) == graph_info_map_[curr_g_].node_map.end()) { @@ -1563,8 +1674,8 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje std::string cell_id = CheckCellChanged(grad, cell, weights, args, &sens_weights_changed); MS_LOG(DEBUG) << "GradNetInner cell_id " << cell_id; - if (!dynamic_shape && !sens_weights_changed.first && !sens_weights_changed.second && - cell_graph_map_.find(cell_id) != cell_graph_map_.end() && cell_graph_map_[cell_id].second) { + if (!sens_weights_changed.first && !sens_weights_changed.second && + cell_graph_map_.find(cell_id) != cell_graph_map_.end() && cell_graph_map_[cell_id].second && !dynamic_cell_) { if (cell_resource_map_.find(cell_id) == cell_resource_map_.end()) { MS_LOG(EXCEPTION) << "Can not find resource"; } @@ -1722,7 +1833,7 @@ std::vector PynativeExecutor::GetWeightsArgs(const py::object &weigh graph_info_map_[df_builder_].node_map.find(param_id) != graph_info_map_[df_builder_].node_map.end()) { para_node = graph_info_map_[df_builder_].node_map[param_id].first; } else { - auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name"); + auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name"); if (py::isinstance(name_attr)) { MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } @@ -1847,6 +1958,7 @@ void PynativeExecutor::Clear(const std::string &flag) { if (!flag.empty()) { MS_LOG(DEBUG) << "Clear cell res"; MapClear>(&cell_resource_map_, flag); + MapClear>(&cell_dynamic_map_, flag); MapClear>>(&cell_graph_map_, flag); MapClear>>(&cell_sw_map_, flag); MapClear>>(&df_builder_map_, flag); @@ -1892,6 +2004,7 @@ void PynativeExecutor::ClearRes() { df_builder_map_.clear(); cell_graph_map_.clear(); cell_resource_map_.clear(); + cell_dynamic_map_.clear(); node_abs_map_.clear(); top_graph_cells_.clear(); } @@ -1922,5 +2035,4 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), "Executor set grad flag."); })); -} // namespace pynative -} // namespace mindspore +} // namespace mindspore::pynative diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 0738a95104..9b707db0ad 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -100,6 +100,14 @@ class PynativeExecutor : public std::enable_shared_from_this { private: PynativeExecutor() = default; + // check cell struct + bool IsDynamicCell(const py::object &cell); + bool ParseIfWhileExprNode(const std::shared_ptr &ast, const py::object &node); + bool ParseAssignExprNode(const std::shared_ptr &ast, const py::object &node); + bool ParseForExprNode(const std::shared_ptr &ast, const py::object &node); + std::string ParseNodeName(const std::shared_ptr &ast, const py::object &node, + parse::AstMainType type); + // run op AnfNodePtr GetInput(const py::object &obj, bool op_mask); MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info); @@ -158,9 +166,9 @@ class PynativeExecutor : public std::enable_shared_from_this { static std::mutex instance_lock_; static int64_t graph_id_; bool grad_flag_{false}; + bool dynamic_cell_{false}; bool first_grad_step_{false}; bool grad_is_running{false}; - bool dynamic_shape{false}; // Used for construct grad graph FuncGraphPtr top_g_{nullptr}; @@ -173,6 +181,7 @@ class PynativeExecutor : public std::enable_shared_from_this { // record all info of a graph std::unordered_map graph_info_map_; + std::unordered_map cell_dynamic_map_; std::unordered_map cell_resource_map_; std::unordered_map> cell_graph_map_; // key: cell_id, value: (send_id, weigths_id), cache for sens and weight change