diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 02e3c4230a..bf084bc80e 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -109,6 +109,7 @@ void Parser::BuildMethodMap() { expr_method_map_["Name"] = &Parser::ParseName; expr_method_map_["Num"] = &Parser::ParseNum; expr_method_map_["Str"] = &Parser::ParseStr; + expr_method_map_["Constant"] = &Parser::ParseConstant; expr_method_map_["NameConstant"] = &Parser::ParseNameConstant; expr_method_map_["Call"] = &Parser::ParseCall; expr_method_map_["IfExp"] = &Parser::ParseIfExp; @@ -557,6 +558,36 @@ AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) { return NewValueNode(str_s); } +AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Constant"; + py::object obj = python_adapter::GetPyObjAttr(node, "value"); + TraceGuard trace_guard(GetLocation(node)); + if (py::isinstance(obj)) { + MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj); + auto data = py::cast(obj); + return NewValueNode(data); + } else if (py::isinstance(obj)) { + MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj); + auto data = py::cast(obj); + return NewValueNode(data); + } else if (py::isinstance(obj)) { + MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj); + auto data = py::cast(obj); + return NewValueNode(data); + } else if (py::isinstance(obj)) { + MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj); + auto data = py::cast(obj); + return NewValueNode(data); + } else if (py::isinstance(obj)) { + MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj); + return NewValueNode(kNone); + } else { + // no else actually + MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj) + << GetLocation(node)->ToString(); + } +} + AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast NameConstant"; py::object obj = python_adapter::GetPyObjAttr(node, "value"); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index 4f349efa30..1941f76dc6 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -143,6 +143,8 @@ class Parser { AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); // process a string variable AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node); + // process a Constant + AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node); // process a name AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); // process a function call