fix-bug-of-misjudge-dynamic-graph-structure-in-pynative

pull/8855/head
lvliang 4 years ago
parent 497d3a42ca
commit 7268bdad15

@ -106,6 +106,7 @@ const char NAMED_PRIMITIVE_COMPARE[] = "Compare";
const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant"; const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant";
const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators"; const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators";
const char NAMED_PRIMITIVE_SLICE[] = "slice"; const char NAMED_PRIMITIVE_SLICE[] = "slice";
const char NAMED_PRIMITIVE_NAME[] = "Name";
const char NAMED_PRIMITIVE_NUM[] = "Num"; const char NAMED_PRIMITIVE_NUM[] = "Num";
const char NAMED_PRIMITIVE_STR[] = "Str"; const char NAMED_PRIMITIVE_STR[] = "Str";
const char NAMED_PRIMITIVE_ITER[] = "iter"; const char NAMED_PRIMITIVE_ITER[] = "iter";

@ -69,6 +69,8 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"}; const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"}; const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
const std::set<std::string> ignore_judge_dynamic_cell = {"Cell mindspore.nn.layer.basic.Dense",
"Cell mindspore.nn.probability.distribution.normal.Normal"};
const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
parse::NAMED_PRIMITIVE_NAMECONSTANT, parse::NAMED_PRIMITIVE_NAMECONSTANT,
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};

@ -1352,6 +1352,16 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
return cell_id; return cell_id;
} }
std::string PynativeExecutor::GetCellInfo(const py::object &cell) {
if (py::isinstance<Cell>(cell)) {
auto c_cell = py::cast<CellPtr>(cell);
MS_EXCEPTION_IF_NULL(c_cell);
auto cell_info = c_cell->ToString();
return cell_info;
}
return "";
}
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node, std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type) { parse::AstMainType type) {
MS_EXCEPTION_IF_NULL(ast); MS_EXCEPTION_IF_NULL(ast);
@ -1372,6 +1382,16 @@ std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAs
return node_name; return node_name;
} }
void PynativeExecutor::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
MS_EXCEPTION_IF_NULL(ast);
py::list args = ast->GetArgs(fn_node);
for (size_t i = 1; i < args.size(); i++) {
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
MS_LOG(DEBUG) << "Input arg name: " << arg_name;
cell_input_args_.emplace(arg_name);
}
}
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) { bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse if/while expr"; MS_LOG(DEBUG) << "Parse if/while expr";
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
@ -1383,14 +1403,25 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAs
MS_LOG(DEBUG) << "Get comparators node falied!"; MS_LOG(DEBUG) << "Get comparators node falied!";
return false; return false;
} }
const auto &left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR); auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
const auto &right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR); auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
MS_LOG(DEBUG) << "left is " << left << " right is " << right; if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
}
MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() || if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) { unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
return true; return true;
} }
} }
// if flag:
if (node_name == parse::NAMED_PRIMITIVE_NAME) {
std::string id = py::cast<std::string>(test_node.attr("id"));
if (cell_input_args_.find(id) != cell_input_args_.end()) {
return true;
}
}
return false; return false;
} }
@ -1405,7 +1436,9 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst
py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE); py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE); py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR); const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
return unchanged_named_primitive.find(node_name_in_slice_node) == unchanged_named_primitive.end(); if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end()) {
return true;
}
} }
} }
return false; return false;
@ -1414,9 +1447,13 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) { bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse for expr"; MS_LOG(DEBUG) << "Parse for expr";
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
if (py::isinstance<py::none>(body_node)) {
MS_LOG(DEBUG) << "Parse body of for expression is none!";
return false;
}
py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
size_t count = LongToSize(pcount); size_t count = LongToSize(pcount);
MS_LOG(DEBUG) << "The for nodes count is " << count; MS_LOG(DEBUG) << "The for nodes count in body is " << count;
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
auto it = py::cast<py::list>(body_node)[i]; auto it = py::cast<py::list>(body_node)[i];
const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
@ -1427,28 +1464,16 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &
return false; return false;
} }
bool PynativeExecutor::IsDynamicCell(const py::object &cell) { bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
std::string cell_info; MS_EXCEPTION_IF_NULL(ast);
if (py::isinstance<Cell>(cell)) { py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
auto c_cell = py::cast<CellPtr>(cell); if (py::isinstance<py::none>(func_obj)) {
MS_EXCEPTION_IF_NULL(c_cell); MS_LOG(DEBUG) << "Parse body of cell is none!";
cell_info = c_cell->ToString();
}
if (cell_info.find("nn.layer.basic.Dense") != string::npos) {
return false;
}
// using ast parse to check whether the construct of cell will be changed
auto ast = std::make_shared<parse::ParseAst>(cell);
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
if (!success) {
MS_LOG(ERROR) << "Parse code to ast tree failed";
return false; return false;
} }
py::object nodes = ast->GetAstNode();
py::object func_obj = parse::python_adapter::GetPyObjAttr(nodes, parse::NAMED_PRIMITIVE_BODY);
py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
size_t count = IntToSize(pcount); size_t count = IntToSize(pcount);
MS_LOG(DEBUG) << "The nodes count is " << count; MS_LOG(DEBUG) << "The nodes count in body is " << count;
bool ret = false; bool ret = false;
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
auto node = py::cast<py::list>(func_obj)[i]; auto node = py::cast<py::list>(func_obj)[i];
@ -1461,13 +1486,35 @@ bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
ret = ParseIfWhileExprNode(ast, node); ret = ParseIfWhileExprNode(ast, node);
} }
if (ret) { if (ret) {
MS_LOG(INFO) << "Cur cell is dynamic"; MS_LOG(INFO) << "Current cell is dynamic!";
break; break;
} }
} }
return ret; return ret;
} }
bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
std::string cell_info = GetCellInfo(cell);
if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
return false;
}
// using ast parse to check whether the construct of cell will be changed
auto ast = std::make_shared<parse::ParseAst>(cell);
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
if (!success) {
MS_LOG(ERROR) << "Parse code to ast tree failed";
return false;
}
py::object fn_node = ast->GetAstNode();
// get the name of input args as the initialize of dynamic_variables
ParseInputArgs(ast, fn_node);
// parse body context
bool ret = false;
ret = ParseBodyContext(ast, fn_node);
cell_input_args_.clear();
return ret;
}
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetCellId(cell, args); auto cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id; MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id;

@ -103,6 +103,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// check cell struct // check cell struct
bool IsDynamicCell(const py::object &cell); bool IsDynamicCell(const py::object &cell);
std::string GetCellInfo(const py::object &cell);
void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node); bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node); bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node); bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
@ -186,6 +189,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// record all info of a graph // record all info of a graph
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_; std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::unordered_set<std::string> cell_input_args_;
std::unordered_map<std::string, bool> cell_dynamic_map_; std::unordered_map<std::string, bool> cell_dynamic_map_;
std::unordered_map<std::string, ResourcePtr> cell_resource_map_; std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
std::unordered_map<std::string, std::pair<FuncGraphPtr, bool>> cell_graph_map_; std::unordered_map<std::string, std::pair<FuncGraphPtr, bool>> cell_graph_map_;

Loading…
Cancel
Save