|
|
|
@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
|
|
|
|
|
static std::string GetId(const py::object &obj) {
|
|
|
|
|
py::object to_process = obj;
|
|
|
|
|
std::string prefix = "";
|
|
|
|
|
if (py::isinstance<py::tuple>(to_process)) {
|
|
|
|
|
if (py::isinstance<py::tuple>(to_process) || py::isinstance<py::list>(to_process)) {
|
|
|
|
|
auto p_list = py::cast<py::tuple>(to_process);
|
|
|
|
|
if (p_list.size() == 0) {
|
|
|
|
|
if (p_list.empty()) {
|
|
|
|
|
return "empty";
|
|
|
|
|
}
|
|
|
|
|
prefix = "tuple:";
|
|
|
|
|
prefix = py::isinstance<py::tuple>(to_process) ? "tuple:" : "list";
|
|
|
|
|
std::string key = "";
|
|
|
|
|
for (size_t i = 0; i < p_list.size(); ++i) {
|
|
|
|
|
key += std::string(py::str(GetId(p_list[i]))) + ":";
|
|
|
|
@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
|
|
|
|
|
auto cell_id = GetId(cell);
|
|
|
|
|
for (size_t i = 0; i < args.size(); i++) {
|
|
|
|
|
std::string arg_id = GetId(args[i]);
|
|
|
|
|
if (node_abs_map_.find(arg_id) != node_abs_map_.end()) {
|
|
|
|
|
cell_id += node_abs_map_[arg_id]->ToString();
|
|
|
|
|
} else {
|
|
|
|
|
AbstractBasePtr abs = abstract::FromValueInside(PyAttrValue(args[i]), true);
|
|
|
|
|
cell_id += abs->ToString();
|
|
|
|
|
node_abs_map_[arg_id] = abs;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return cell_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
|
|
|
|
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
|
|
|
|
|
mindspore::parse::python_adapter::set_python_env_flag(true);
|
|
|
|
@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
|
|
|
|
|
}
|
|
|
|
|
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
|
|
|
|
|
bool is_find = false;
|
|
|
|
|
if (prim_abs_list.find(prim->id()) != prim_abs_list.end()) {
|
|
|
|
|
auto abs_list = prim_abs_list[prim->id()];
|
|
|
|
|
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
|
|
|
|
|
auto abs_list = prim_abs_list_[prim->id()];
|
|
|
|
|
MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
|
|
|
|
|
if (abs_list.find(args_spec_list) != abs_list.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name;
|
|
|
|
@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
|
|
|
|
|
|
|
|
|
|
if (!is_find) {
|
|
|
|
|
// const_value need infer every step
|
|
|
|
|
auto &out = prim_abs_list[prim->id()];
|
|
|
|
|
auto &out = prim_abs_list_[prim->id()];
|
|
|
|
|
out[args_spec_list].abs = op_exec_info->abstract;
|
|
|
|
|
out[args_spec_list].attrs = prim->evaluate_added_attrs();
|
|
|
|
|
MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
|
|
|
|
@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
|
|
|
|
|
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
|
|
|
|
|
auto cell_id = GetId(cell);
|
|
|
|
|
auto cell_id = GetCellId(cell, args);
|
|
|
|
|
if (cell_graph_map_.count(cell_id) != 0) {
|
|
|
|
|
if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
|
|
|
|
|
resource_ = cell_resource_map_[cell_id];
|
|
|
|
@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
|
|
|
|
|
auto cell_id = GetId(cell);
|
|
|
|
|
auto cell_id = GetCellId(cell, args);
|
|
|
|
|
if (cell_graph_map_.count(cell_id) != 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "Endgraph already compiled";
|
|
|
|
|
return;
|
|
|
|
@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
|
|
|
|
|
inputs.push_back(input);
|
|
|
|
|
}
|
|
|
|
|
auto out_cnode = curr_g_->NewCNode(inputs);
|
|
|
|
|
set_pyobj(curr_g_, GetId(cell));
|
|
|
|
|
set_pyobj(curr_g_, GetCellId(cell, args));
|
|
|
|
|
if (py::isinstance<py::tuple>(out)) {
|
|
|
|
|
auto out_list = py::cast<py::tuple>(out);
|
|
|
|
|
auto out_size = static_cast<int>(out_list.size());
|
|
|
|
@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
|
|
|
|
|
MS_LOG(INFO) << "GradNet start" << args.size();
|
|
|
|
|
|
|
|
|
|
std::size_t size = args.size();
|
|
|
|
|
auto cell_id = GetId(cell);
|
|
|
|
|
std::string cell_id = GetCellId(cell, args);
|
|
|
|
|
if (graph_map_.count(cell_id) != 0) {
|
|
|
|
|
MS_LOG(DEBUG) << "GradNet already compiled";
|
|
|
|
|
return;
|
|
|
|
|