|
|
@ -1352,6 +1352,16 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
|
|
|
|
return cell_id;
|
|
|
|
return cell_id;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string PynativeExecutor::GetCellInfo(const py::object &cell) {
|
|
|
|
|
|
|
|
if (py::isinstance<Cell>(cell)) {
|
|
|
|
|
|
|
|
auto c_cell = py::cast<CellPtr>(cell);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_cell);
|
|
|
|
|
|
|
|
auto cell_info = c_cell->ToString();
|
|
|
|
|
|
|
|
return cell_info;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return "";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
|
|
|
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
|
|
|
parse::AstMainType type) {
|
|
|
|
parse::AstMainType type) {
|
|
|
|
MS_EXCEPTION_IF_NULL(ast);
|
|
|
|
MS_EXCEPTION_IF_NULL(ast);
|
|
|
@ -1372,6 +1382,16 @@ std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAs
|
|
|
|
return node_name;
|
|
|
|
return node_name;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ast);
|
|
|
|
|
|
|
|
py::list args = ast->GetArgs(fn_node);
|
|
|
|
|
|
|
|
for (size_t i = 1; i < args.size(); i++) {
|
|
|
|
|
|
|
|
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Input arg name: " << arg_name;
|
|
|
|
|
|
|
|
cell_input_args_.emplace(arg_name);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
|
|
|
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
|
|
|
MS_LOG(DEBUG) << "Parse if/while expr";
|
|
|
|
MS_LOG(DEBUG) << "Parse if/while expr";
|
|
|
|
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
|
|
|
|
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
|
|
|
@ -1383,14 +1403,25 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAs
|
|
|
|
MS_LOG(DEBUG) << "Get comparators node falied!";
|
|
|
|
MS_LOG(DEBUG) << "Get comparators node falied!";
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
const auto &left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
|
|
|
|
auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
|
|
|
|
const auto &right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
|
|
|
|
auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
|
|
|
|
MS_LOG(DEBUG) << "left is " << left << " right is " << right;
|
|
|
|
if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
|
|
|
|
|
|
|
|
py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE);
|
|
|
|
|
|
|
|
left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Left is " << left << " Right is " << right;
|
|
|
|
if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
|
|
|
|
if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
|
|
|
|
unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
|
|
|
|
unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// if flag:
|
|
|
|
|
|
|
|
if (node_name == parse::NAMED_PRIMITIVE_NAME) {
|
|
|
|
|
|
|
|
std::string id = py::cast<std::string>(test_node.attr("id"));
|
|
|
|
|
|
|
|
if (cell_input_args_.find(id) != cell_input_args_.end()) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -1405,7 +1436,9 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst
|
|
|
|
py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
|
|
|
|
py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
|
|
|
|
py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
|
|
|
|
py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
|
|
|
|
const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
|
|
|
|
const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
|
|
|
|
return unchanged_named_primitive.find(node_name_in_slice_node) == unchanged_named_primitive.end();
|
|
|
|
if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end()) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
@ -1414,9 +1447,13 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst
|
|
|
|
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
|
|
|
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
|
|
|
|
MS_LOG(DEBUG) << "Parse for expr";
|
|
|
|
MS_LOG(DEBUG) << "Parse for expr";
|
|
|
|
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
|
|
|
|
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
|
|
|
|
|
|
|
|
if (py::isinstance<py::none>(body_node)) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Parse body of for expression is none!";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
|
|
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
|
|
|
|
size_t count = LongToSize(pcount);
|
|
|
|
size_t count = LongToSize(pcount);
|
|
|
|
MS_LOG(DEBUG) << "The for nodes count is " << count;
|
|
|
|
MS_LOG(DEBUG) << "The for nodes count in body is " << count;
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
|
|
auto it = py::cast<py::list>(body_node)[i];
|
|
|
|
auto it = py::cast<py::list>(body_node)[i];
|
|
|
|
const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
|
|
|
|
const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
|
|
|
@ -1427,28 +1464,16 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
|
|
|
|
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) {
|
|
|
|
std::string cell_info;
|
|
|
|
MS_EXCEPTION_IF_NULL(ast);
|
|
|
|
if (py::isinstance<Cell>(cell)) {
|
|
|
|
py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY);
|
|
|
|
auto c_cell = py::cast<CellPtr>(cell);
|
|
|
|
if (py::isinstance<py::none>(func_obj)) {
|
|
|
|
MS_EXCEPTION_IF_NULL(c_cell);
|
|
|
|
MS_LOG(DEBUG) << "Parse body of cell is none!";
|
|
|
|
cell_info = c_cell->ToString();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (cell_info.find("nn.layer.basic.Dense") != string::npos) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// using ast parse to check whether the construct of cell will be changed
|
|
|
|
|
|
|
|
auto ast = std::make_shared<parse::ParseAst>(cell);
|
|
|
|
|
|
|
|
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
|
|
|
|
|
|
|
|
if (!success) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Parse code to ast tree failed";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
py::object nodes = ast->GetAstNode();
|
|
|
|
|
|
|
|
py::object func_obj = parse::python_adapter::GetPyObjAttr(nodes, parse::NAMED_PRIMITIVE_BODY);
|
|
|
|
|
|
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
|
|
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
|
|
|
|
size_t count = IntToSize(pcount);
|
|
|
|
size_t count = IntToSize(pcount);
|
|
|
|
MS_LOG(DEBUG) << "The nodes count is " << count;
|
|
|
|
MS_LOG(DEBUG) << "The nodes count in body is " << count;
|
|
|
|
bool ret = false;
|
|
|
|
bool ret = false;
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
|
|
for (size_t i = 0; i < count; ++i) {
|
|
|
|
auto node = py::cast<py::list>(func_obj)[i];
|
|
|
|
auto node = py::cast<py::list>(func_obj)[i];
|
|
|
@ -1461,13 +1486,35 @@ bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
|
|
|
|
ret = ParseIfWhileExprNode(ast, node);
|
|
|
|
ret = ParseIfWhileExprNode(ast, node);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (ret) {
|
|
|
|
if (ret) {
|
|
|
|
MS_LOG(INFO) << "Cur cell is dynamic";
|
|
|
|
MS_LOG(INFO) << "Current cell is dynamic!";
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
|
|
|
|
|
|
|
|
std::string cell_info = GetCellInfo(cell);
|
|
|
|
|
|
|
|
if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// using ast parse to check whether the construct of cell will be changed
|
|
|
|
|
|
|
|
auto ast = std::make_shared<parse::ParseAst>(cell);
|
|
|
|
|
|
|
|
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
|
|
|
|
|
|
|
|
if (!success) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Parse code to ast tree failed";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
py::object fn_node = ast->GetAstNode();
|
|
|
|
|
|
|
|
// get the name of input args as the initialize of dynamic_variables
|
|
|
|
|
|
|
|
ParseInputArgs(ast, fn_node);
|
|
|
|
|
|
|
|
// parse body context
|
|
|
|
|
|
|
|
bool ret = false;
|
|
|
|
|
|
|
|
ret = ParseBodyContext(ast, fn_node);
|
|
|
|
|
|
|
|
cell_input_args_.clear();
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
|
|
|
|
auto cell_id = GetCellId(cell, args);
|
|
|
|
auto cell_id = GetCellId(cell, args);
|
|
|
|
MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id;
|
|
|
|
MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id;
|
|
|
|