|
|
@ -738,27 +738,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|
|
|
inputs.emplace_back(input_node);
|
|
|
|
inputs.emplace_back(input_node);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
(*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i));
|
|
|
|
auto const_input_index = prim->get_const_input_indexes();
|
|
|
|
|
|
|
|
bool have_const_input = !const_input_index.empty();
|
|
|
|
|
|
|
|
bool is_const_prim = prim->is_const_prim();
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
|
|
|
|
|
|
|
|
<< prim->is_const_prim();
|
|
|
|
|
|
|
|
bool is_const_input =
|
|
|
|
|
|
|
|
have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
|
|
|
|
|
|
|
|
if (abs == nullptr || is_const_prim || is_const_input) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
|
|
|
|
|
|
|
|
ValuePtr input_value = PyAttrValue(obj);
|
|
|
|
|
|
|
|
abs = input_value->ToAbstract();
|
|
|
|
|
|
|
|
if (!is_const_prim && !is_const_input) {
|
|
|
|
|
|
|
|
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
|
|
|
|
|
|
|
abs = abs->Broaden(config);
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
node_abs_map_[id] = abs;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(*args_spec_list).emplace_back(abs);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
CNodePtr cnode = nullptr;
|
|
|
|
CNodePtr cnode = nullptr;
|
|
|
@ -770,6 +750,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|
|
|
return cnode;
|
|
|
|
return cnode;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
abstract::AbstractBasePtr PynativeExecutor::CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
|
|
|
|
|
|
|
const abstract::AbstractBasePtr &abs, const std::string &id,
|
|
|
|
|
|
|
|
size_t index) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
|
|
|
auto const_input_index = prim->get_const_input_indexes();
|
|
|
|
|
|
|
|
bool have_const_input = !const_input_index.empty();
|
|
|
|
|
|
|
|
bool is_const_prim = prim->is_const_prim();
|
|
|
|
|
|
|
|
auto new_abs = abs;
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
|
|
|
|
|
|
|
|
<< prim->is_const_prim();
|
|
|
|
|
|
|
|
bool is_const_input =
|
|
|
|
|
|
|
|
have_const_input && std::find(const_input_index.begin(), const_input_index.end(), index) != const_input_index.end();
|
|
|
|
|
|
|
|
if (abs == nullptr || is_const_prim || is_const_input) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
|
|
|
|
|
|
|
|
ValuePtr input_value = PyAttrValue(obj);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_value);
|
|
|
|
|
|
|
|
new_abs = input_value->ToAbstract();
|
|
|
|
|
|
|
|
if (!is_const_prim && !is_const_input) {
|
|
|
|
|
|
|
|
auto config = abstract::AbstractBase::kBroadenTensorOnly;
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_abs);
|
|
|
|
|
|
|
|
new_abs = new_abs->Broaden(config);
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
node_abs_map_[id] = new_abs;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return new_abs;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
|
|
|
void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
|
|
|
const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
|
|
|
|
const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
|
|
|
|
MS_EXCEPTION_IF_NULL(is_find);
|
|
|
|
MS_EXCEPTION_IF_NULL(is_find);
|
|
|
@ -1004,6 +1012,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
|
|
|
|
return free_param;
|
|
|
|
return free_param;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
node = graph_info->node_map.at(obj_id).first;
|
|
|
|
node = graph_info->node_map.at(obj_id).first;
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
|
|
|
|
MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
|
|
|
|
return node;
|
|
|
|
return node;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -2008,9 +2017,14 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
|
|
|
|
top_cell_id_ = cell_id;
|
|
|
|
top_cell_id_ = cell_id;
|
|
|
|
in_grad_process_ = true;
|
|
|
|
in_grad_process_ = true;
|
|
|
|
// update forward already run flag with previous top cell
|
|
|
|
// update forward already run flag with previous top cell
|
|
|
|
|
|
|
|
std::string input_args_id;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
|
|
|
|
|
|
input_args_id = input_args_id + GetId(args[i]) + "_";
|
|
|
|
|
|
|
|
}
|
|
|
|
auto pre_top_cell = GetTopCell(cell_id);
|
|
|
|
auto pre_top_cell = GetTopCell(cell_id);
|
|
|
|
if (pre_top_cell != nullptr) {
|
|
|
|
if (pre_top_cell != nullptr) {
|
|
|
|
pre_top_cell->forward_already_run = true;
|
|
|
|
pre_top_cell->forward_already_run = true;
|
|
|
|
|
|
|
|
pre_top_cell->input_args_id = input_args_id;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto df_builder = std::make_shared<FuncGraph>();
|
|
|
|
auto df_builder = std::make_shared<FuncGraph>();
|
|
|
|
auto graph_info = std::make_shared<GraphInfo>(cell_id);
|
|
|
|
auto graph_info = std::make_shared<GraphInfo>(cell_id);
|
|
|
@ -2019,6 +2033,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
|
|
|
|
resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
|
|
|
|
resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
|
|
|
|
auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
|
|
|
|
auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
|
|
|
|
top_cell_info->forward_already_run = true;
|
|
|
|
top_cell_info->forward_already_run = true;
|
|
|
|
|
|
|
|
top_cell_info->input_args_id = input_args_id;
|
|
|
|
if (!IsTopestGraph(cell_id)) {
|
|
|
|
if (!IsTopestGraph(cell_id)) {
|
|
|
|
top_cell_info->top_cell_index = cell_graph_list_.size();
|
|
|
|
top_cell_info->top_cell_index = cell_graph_list_.size();
|
|
|
|
top_cell_index_ = top_cell_info->top_cell_index;
|
|
|
|
top_cell_index_ = top_cell_info->top_cell_index;
|
|
|
@ -2862,11 +2877,24 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
|
|
|
|
py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
|
|
|
|
|
|
|
|
bool forward_run = false;
|
|
|
|
const auto &cell_id = GetCellId(cell, args);
|
|
|
|
const auto &cell_id = GetCellId(cell, args);
|
|
|
|
|
|
|
|
// Checkout whether top cell has already run.
|
|
|
|
|
|
|
|
std::string input_args_id;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
|
|
|
|
|
|
input_args_id = input_args_id + GetId(args[i]) + "_";
|
|
|
|
|
|
|
|
}
|
|
|
|
auto top_cell = GetTopCell(cell_id);
|
|
|
|
auto top_cell = GetTopCell(cell_id);
|
|
|
|
bool forward_run = false;
|
|
|
|
|
|
|
|
if (top_cell != nullptr) {
|
|
|
|
if (top_cell != nullptr) {
|
|
|
|
forward_run = top_cell->forward_already_run;
|
|
|
|
if (!top_cell->input_args_id.empty() && top_cell->input_args_id != input_args_id && top_cell->forward_already_run &&
|
|
|
|
|
|
|
|
CheckDynamicCell(cell_id)) {
|
|
|
|
|
|
|
|
MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, "
|
|
|
|
|
|
|
|
"forward process will run again";
|
|
|
|
|
|
|
|
top_cell->forward_already_run = false;
|
|
|
|
|
|
|
|
top_cell->input_args_id = input_args_id;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
forward_run = top_cell->forward_already_run;
|
|
|
|
|
|
|
|
}
|
|
|
|
if (forward_run) {
|
|
|
|
if (forward_run) {
|
|
|
|
top_cell_index_ = top_cell->top_cell_index;
|
|
|
|
top_cell_index_ = top_cell->top_cell_index;
|
|
|
|
}
|
|
|
|
}
|
|
|
|