|
|
|
@ -22,6 +22,7 @@
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
|
|
#include "debug/trace.h"
|
|
|
|
|
#include "ir/tensor_py.h"
|
|
|
|
|
#include "ir/param_value_py.h"
|
|
|
|
|
#include "utils/any.h"
|
|
|
|
@ -66,6 +67,42 @@ PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
|
|
|
|
|
std::mutex PynativeExecutor::instance_lock_;
|
|
|
|
|
ResourcePtr PynativeExecutor::resource_;
|
|
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
|
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
|
|
|
|
|
try {
|
|
|
|
|
(executor->*method)(args...);
|
|
|
|
|
} catch (const py::error_already_set &ex) {
|
|
|
|
|
// print function call stack info before release
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
trace::TraceGraphEval();
|
|
|
|
|
trace::GetEvalStackInfo(oss);
|
|
|
|
|
// call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
|
|
|
|
|
// these info from screen, no need to open log file to find these info
|
|
|
|
|
py::print(oss.str());
|
|
|
|
|
MS_LOG(ERROR) << oss.str();
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
// re-throw this exception to Python interpreter to handle it
|
|
|
|
|
throw(py::error_already_set(ex));
|
|
|
|
|
} catch (const py::type_error &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
throw py::type_error(ex);
|
|
|
|
|
} catch (const py::value_error &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
throw py::value_error(ex);
|
|
|
|
|
} catch (const py::index_error &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
throw py::index_error(ex);
|
|
|
|
|
} catch (const std::exception &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
// re-throw this exception to Python interpreter to handle it
|
|
|
|
|
throw(std::runtime_error(ex.what()));
|
|
|
|
|
} catch (...) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
std::string exName(abi::__cxa_current_exception_type()->name());
|
|
|
|
|
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline ValuePtr PyAttrValue(const py::object &obj) {
|
|
|
|
|
ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
|
|
|
|
|
if (!converted_ret) {
|
|
|
|
@ -144,7 +181,7 @@ std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args,
|
|
|
|
|
py::list *out_args_list) {
|
|
|
|
|
py::list *const out_args_list) {
|
|
|
|
|
auto &py_args = *out_args;
|
|
|
|
|
py::tuple input_mask(args.size());
|
|
|
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
|
|
@ -564,7 +601,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
|
|
|
|
|
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
|
|
|
|
|
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
|
|
|
|
|
mindspore::parse::python_adapter::set_python_env_flag(true);
|
|
|
|
|
MsBackendPolicy backend_policy;
|
|
|
|
@ -603,7 +640,7 @@ py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple RunOp(const py::args &args) {
|
|
|
|
|
py::tuple RunOpInner(const py::args &args) {
|
|
|
|
|
MS_LOG(DEBUG) << "RunOp start" << args.size();
|
|
|
|
|
py::list args_input = args[PY_INPUTS];
|
|
|
|
|
|
|
|
|
@ -623,7 +660,42 @@ py::tuple RunOp(const py::args &args) {
|
|
|
|
|
return value_ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RunOp(op_exec_info, args_input);
|
|
|
|
|
return RunOpInner(op_exec_info, args_input);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple RunOp(const py::args &args) {
|
|
|
|
|
try {
|
|
|
|
|
return RunOpInner(args);
|
|
|
|
|
} catch (const py::error_already_set &ex) {
|
|
|
|
|
// print function call stack info before release
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
trace::TraceGraphEval();
|
|
|
|
|
trace::GetEvalStackInfo(oss);
|
|
|
|
|
// call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
|
|
|
|
|
// these info from screen, no need to open log file to find these info
|
|
|
|
|
py::print(oss.str());
|
|
|
|
|
MS_LOG(ERROR) << oss.str();
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
// re-throw this exception to Python interpreter to handle it
|
|
|
|
|
throw(py::error_already_set(ex));
|
|
|
|
|
} catch (const py::type_error &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
throw py::type_error(ex);
|
|
|
|
|
} catch (const py::value_error &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
throw py::value_error(ex);
|
|
|
|
|
} catch (const py::index_error &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
throw py::index_error(ex);
|
|
|
|
|
} catch (const std::exception &ex) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
// re-throw this exception to Python interpreter to handle it
|
|
|
|
|
throw(std::runtime_error(ex.what()));
|
|
|
|
|
} catch (...) {
|
|
|
|
|
PynativeExecutor::GetInstance()->Clean();
|
|
|
|
|
std::string exName(abi::__cxa_current_exception_type()->name());
|
|
|
|
|
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ClearPyNativeSession() { session = nullptr; }
|
|
|
|
@ -632,7 +704,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
|
|
|
|
|
|
|
|
|
|
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
|
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
|
|
|
|
|
auto cell_id = GetId(cell);
|
|
|
|
|
if (cell_graph_map_.count(cell_id) != 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "Newgraph already compiled";
|
|
|
|
@ -753,7 +825,7 @@ void PynativeExecutor::Popp() {
|
|
|
|
|
graph_p_.pop();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
|
|
|
|
|
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
|
|
|
|
|
auto cell_id = GetId(cell);
|
|
|
|
|
if (cell_graph_map_.count(cell_id) != 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "Endgraph already compiled";
|
|
|
|
@ -892,8 +964,8 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
|
|
|
|
|
return args_spec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
|
|
|
|
const py::args &args) {
|
|
|
|
|
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
|
|
|
|
const py::args &args) {
|
|
|
|
|
MS_LOG(INFO) << "GradNet start" << args.size();
|
|
|
|
|
|
|
|
|
|
std::size_t size = args.size();
|
|
|
|
@ -939,8 +1011,10 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::Clear(const std::string &flag) {
|
|
|
|
|
if (flag == "resource") {
|
|
|
|
|
if (!flag.empty()) {
|
|
|
|
|
MS_LOG(INFO) << "Clear res";
|
|
|
|
|
(void)graph_map_.erase(flag);
|
|
|
|
|
(void)cell_graph_map_.erase(flag);
|
|
|
|
|
Clean();
|
|
|
|
|
// Maybe exit in the pynative runing op, so need reset pynative flag.
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
@ -949,6 +1023,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Clear";
|
|
|
|
|
top_g_ = nullptr;
|
|
|
|
|
curr_g_ = nullptr;
|
|
|
|
@ -1010,6 +1085,19 @@ FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr
|
|
|
|
|
return df_builder_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
|
|
|
|
|
PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
|
|
|
|
|
PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
|
|
|
|
const py::args &args) {
|
|
|
|
|
PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
|
|
|
|
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
|
|
|
|
|
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
|
|
|
|
|