Fix a bug in how PyNative mode catches and re-throws exceptions.

pull/2544/head
rick_sanchez 5 years ago
parent 044b214636
commit 0a0ca3cfdc

@ -22,6 +22,7 @@
#include <unordered_set>
#include <algorithm>
#include "debug/trace.h"
#include "ir/tensor_py.h"
#include "ir/param_value_py.h"
#include "utils/any.h"
@ -66,6 +67,42 @@ PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
ResourcePtr PynativeExecutor::resource_;
// Calls `method` on `executor` with `args`. If the call throws, Clean() is invoked on the instance returned by
// PynativeExecutor::GetInstance() before the error is propagated to the caller.
//
// Handling per exception type:
//  - py::error_already_set: the evaluation stack info is printed with py::print and logged at ERROR level, then the
//    original exception is re-thrown.
//  - py::type_error / py::value_error / py::index_error: the original exception is re-thrown.
//  - any other std::exception: re-thrown as a std::runtime_error carrying the original what() text.
//  - anything else: the exception type name is reported through MS_LOG(EXCEPTION).
//
// The typed handlers use a bare `throw;` so the in-flight exception object is re-raised as-is instead of being
// copy-constructed from a const reference (which would copy it and slice any more-derived type).
template <typename... Args>
void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) {
  try {
    (executor->*method)(args...);
  } catch (const py::error_already_set &) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->Clean();
    // re-throw the original exception to Python interpreter to handle it
    throw;
  } catch (const py::type_error &) {
    PynativeExecutor::GetInstance()->Clean();
    throw;
  } catch (const py::value_error &) {
    PynativeExecutor::GetInstance()->Clean();
    throw;
  } catch (const py::index_error &) {
    PynativeExecutor::GetInstance()->Clean();
    throw;
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->Clean();
    // intentionally converted: every remaining std::exception is surfaced as std::runtime_error
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->Clean();
    // not derived from std::exception: only the type name of the in-flight exception is available
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
}
inline ValuePtr PyAttrValue(const py::object &obj) {
ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj);
if (!converted_ret) {
@ -144,7 +181,7 @@ std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
}
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args,
py::list *out_args_list) {
py::list *const out_args_list) {
auto &py_args = *out_args;
py::tuple input_mask(args.size());
for (size_t i = 0; i < args.size(); ++i) {
@ -564,7 +601,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
return node;
}
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);
MsBackendPolicy backend_policy;
@ -603,7 +640,7 @@ py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
return result;
}
py::tuple RunOp(const py::args &args) {
py::tuple RunOpInner(const py::args &args) {
MS_LOG(DEBUG) << "RunOp start" << args.size();
py::list args_input = args[PY_INPUTS];
@ -623,7 +660,42 @@ py::tuple RunOp(const py::args &args) {
return value_ret;
}
}
return RunOp(op_exec_info, args_input);
return RunOpInner(op_exec_info, args_input);
}
// Entry point for running a single op: delegates to RunOpInner(args) and returns its result. If RunOpInner throws,
// Clean() is invoked on the instance returned by PynativeExecutor::GetInstance() before the error is propagated.
//
// Handling per exception type:
//  - py::error_already_set: the evaluation stack info is printed with py::print and logged at ERROR level, then the
//    original exception is re-thrown.
//  - py::type_error / py::value_error / py::index_error: the original exception is re-thrown.
//  - any other std::exception: re-thrown as a std::runtime_error carrying the original what() text.
//  - anything else: the exception type name is reported through MS_LOG(EXCEPTION).
//
// The typed handlers use a bare `throw;` so the in-flight exception object is re-raised as-is instead of being
// copy-constructed from a const reference (which would copy it and slice any more-derived type).
py::tuple RunOp(const py::args &args) {
  try {
    return RunOpInner(args);
  } catch (const py::error_already_set &) {
    // print function call stack info before release
    std::ostringstream oss;
    trace::TraceGraphEval();
    trace::GetEvalStackInfo(oss);
    // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
    // these info from screen, no need to open log file to find these info
    py::print(oss.str());
    MS_LOG(ERROR) << oss.str();
    PynativeExecutor::GetInstance()->Clean();
    // re-throw the original exception to Python interpreter to handle it
    throw;
  } catch (const py::type_error &) {
    PynativeExecutor::GetInstance()->Clean();
    throw;
  } catch (const py::value_error &) {
    PynativeExecutor::GetInstance()->Clean();
    throw;
  } catch (const py::index_error &) {
    PynativeExecutor::GetInstance()->Clean();
    throw;
  } catch (const std::exception &ex) {
    PynativeExecutor::GetInstance()->Clean();
    // intentionally converted: every remaining std::exception is surfaced as std::runtime_error
    throw(std::runtime_error(ex.what()));
  } catch (...) {
    PynativeExecutor::GetInstance()->Clean();
    // not derived from std::exception: only the type name of the in-flight exception is available
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
  }
}
// Resets the `session` variable (declared outside this function) to null.
void ClearPyNativeSession() {
  session = nullptr;
}
@ -632,7 +704,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "Newgraph already compiled";
@ -753,7 +825,7 @@ void PynativeExecutor::Popp() {
graph_p_.pop();
}
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
auto cell_id = GetId(cell);
if (cell_graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "Endgraph already compiled";
@ -892,7 +964,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
return args_spec;
}
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
MS_LOG(INFO) << "GradNet start" << args.size();
@ -939,8 +1011,10 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
}
void PynativeExecutor::Clear(const std::string &flag) {
if (flag == "resource") {
if (!flag.empty()) {
MS_LOG(INFO) << "Clear res";
(void)graph_map_.erase(flag);
(void)cell_graph_map_.erase(flag);
Clean();
// Maybe exit in the pynative runing op, so need reset pynative flag.
auto ms_context = MsContext::GetInstance();
@ -949,6 +1023,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
}
return;
}
MS_LOG(INFO) << "Clear";
top_g_ = nullptr;
curr_g_ = nullptr;
@ -1010,6 +1085,19 @@ FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr
return df_builder_;
}
// Runs NewGraphInner(cell, args) through PynativeExecutorTry, which performs the cleanup and re-throw handling
// if the inner call raises.
void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
  using InnerFn = void (PynativeExecutor::*)(const py::object &, const py::args &);
  const InnerFn inner = &PynativeExecutor::NewGraphInner;
  PynativeExecutorTry(this, inner, cell, args);
}
// Runs EndGraphInner(cell, out, args) through PynativeExecutorTry, which performs the cleanup and re-throw
// handling if the inner call raises.
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  using InnerFn = void (PynativeExecutor::*)(const py::object &, const py::object &, const py::args &);
  const InnerFn inner = &PynativeExecutor::EndGraphInner;
  PynativeExecutorTry(this, inner, cell, out, args);
}
// Runs GradNetInner(grad, cell, weights, args) through PynativeExecutorTry, which performs the cleanup and
// re-throw handling if the inner call raises.
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                               const py::args &args) {
  using InnerFn = void (PynativeExecutor::*)(const GradOperationPtr &, const py::object &, const py::object &,
                                             const py::args &);
  const InnerFn inner = &PynativeExecutor::GradNetInner;
  PynativeExecutorTry(this, inner, grad, cell, weights, args);
}
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")

@ -46,7 +46,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple RunOp(const py::args &args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
py::list *out_args_list);
py::list *const out_args_list);
void ClearPyNativeSession();
@ -68,11 +68,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
return executor_;
}
void NewGraph(const py::object &cell, const py::args &args);
void NewGraphInner(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args);
void Clear(const std::string &flag = "");
void Clean();
void ClearRes();

@ -186,7 +186,7 @@ class Cell:
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
def __del__(self):
_pynative_exec.clear("resource")
_pynative_exec.clear(str(id(self)))
if hasattr(self, "_create_time"):
_executor.del_net_res(str(self._create_time))

Loading…
Cancel
Save