|
|
|
@ -467,6 +467,11 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
|
|
|
|
|
|
|
|
|
|
opt::ConstInputToAttrInfoRegister reg;
|
|
|
|
|
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®);
|
|
|
|
|
if (op_run_info->is_dynamic_shape &&
|
|
|
|
|
dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
|
|
|
|
|
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
|
|
|
|
|
reg_exist = false;
|
|
|
|
|
}
|
|
|
|
|
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
|
|
|
|
|
reg_exist = false;
|
|
|
|
|
}
|
|
|
|
@ -594,6 +599,7 @@ py::tuple RunOp(const py::args &args) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
|
|
|
|
if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
|
|
|
|
|
return RunOpWithInitBackendPolicy(op_exec_info);
|
|
|
|
|
}
|
|
|
|
@ -604,58 +610,27 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
|
|
|
|
op_exec_info->inputs_mask = op_masks;
|
|
|
|
|
// get output abstract info
|
|
|
|
|
bool is_find = false;
|
|
|
|
|
GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find);
|
|
|
|
|
MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
|
|
|
|
|
// infer output value for const prim
|
|
|
|
|
auto prim = op_exec_info->py_primitive;
|
|
|
|
|
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
|
|
|
|
|
auto abs_list = prim_abs_list_[prim->id()];
|
|
|
|
|
MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
|
|
|
|
|
if (abs_list.find(args_spec_list) != abs_list.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name;
|
|
|
|
|
op_exec_info->abstract = abs_list[args_spec_list].abs;
|
|
|
|
|
op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape;
|
|
|
|
|
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
|
|
|
|
|
is_find = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) {
|
|
|
|
|
// use python infer method
|
|
|
|
|
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
|
|
|
|
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
|
|
|
|
|
}
|
|
|
|
|
// get output dynamic shape info
|
|
|
|
|
auto abstract = op_exec_info->abstract;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract);
|
|
|
|
|
auto shape = abstract->BuildShape();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(shape);
|
|
|
|
|
auto shape_info = shape->ToString();
|
|
|
|
|
if (shape_info.find("-1") != string::npos) {
|
|
|
|
|
op_exec_info->is_dynamic_shape = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (cnode != nullptr) {
|
|
|
|
|
cnode->set_abstract(op_exec_info->abstract);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
|
|
|
|
|
if (!output["value"].is_none()) {
|
|
|
|
|
py::tuple value_ret(1);
|
|
|
|
|
value_ret[0] = output["value"];
|
|
|
|
|
return value_ret;
|
|
|
|
|
}
|
|
|
|
|
// infer output value for const prim
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
|
|
|
|
if (op_exec_info->abstract != nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
|
|
|
|
|
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
|
|
|
|
|
if (!output["value"].is_none()) {
|
|
|
|
|
py::tuple value_ret(1);
|
|
|
|
|
value_ret[0] = output["value"];
|
|
|
|
|
return value_ret;
|
|
|
|
|
}
|
|
|
|
|
if (op_exec_info->py_primitive->is_const_prim()) {
|
|
|
|
|
py::tuple value_ret(1);
|
|
|
|
|
value_ret[0] = "";
|
|
|
|
|
return value_ret;
|
|
|
|
|
}
|
|
|
|
|
if (prim->is_const_prim()) {
|
|
|
|
|
py::tuple value_ret(1);
|
|
|
|
|
value_ret[0] = "";
|
|
|
|
|
return value_ret;
|
|
|
|
|
}
|
|
|
|
|
// add output abstract info into cache
|
|
|
|
|
if (!is_find) {
|
|
|
|
|
if (!is_find && !op_exec_info->is_dynamic_shape) {
|
|
|
|
|
// const_value need infer every step
|
|
|
|
|
auto &out = prim_abs_list_[prim->id()];
|
|
|
|
|
out[args_spec_list].abs = op_exec_info->abstract;
|
|
|
|
|
out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape;
|
|
|
|
|
out[args_spec_list].attrs = prim->evaluate_added_attrs();
|
|
|
|
|
MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
|
|
|
|
|
}
|
|
|
|
@ -666,8 +641,13 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
|
|
|
|
MS_LOG(DEBUG) << "Output size is 1";
|
|
|
|
|
out_real = result[0];
|
|
|
|
|
}
|
|
|
|
|
// update output abstract for cnode
|
|
|
|
|
if (cnode != nullptr) {
|
|
|
|
|
cnode->set_abstract(op_exec_info->abstract);
|
|
|
|
|
}
|
|
|
|
|
std::string obj_id = GetId(out_real);
|
|
|
|
|
node_abs_map_[obj_id] = op_exec_info->abstract;
|
|
|
|
|
// save info for building grad graph
|
|
|
|
|
SaveOutputNodeMap(obj_id, out_real, cnode);
|
|
|
|
|
SaveAllResult(op_exec_info, cnode, out_real);
|
|
|
|
|
// Update the abstract and device address of value node with tensor in grad graph
|
|
|
|
@ -784,6 +764,49 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|
|
|
|
return cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
|
|
|
|
const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(is_find);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
|
|
|
|
*is_find = false;
|
|
|
|
|
auto op_name = op_exec_info->op_name;
|
|
|
|
|
auto prim = op_exec_info->py_primitive;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
|
|
|
|
|
auto abs_list = prim_abs_list_[prim->id()];
|
|
|
|
|
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
|
|
|
|
|
if (abs_list.find(args_spec_list) != abs_list.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Match prim ok " << op_name;
|
|
|
|
|
op_exec_info->abstract = abs_list[args_spec_list].abs;
|
|
|
|
|
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
|
|
|
|
|
*is_find = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
|
|
|
|
|
// use python infer method
|
|
|
|
|
if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
|
|
|
|
|
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// get output dynamic shape info
|
|
|
|
|
auto py_abstract = op_exec_info->abstract;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(py_abstract);
|
|
|
|
|
auto py_shape = py_abstract->BuildShape();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(py_shape);
|
|
|
|
|
auto py_shape_info = py_shape->ToString();
|
|
|
|
|
if (py_shape_info.find("-1") != string::npos) {
|
|
|
|
|
auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_abstract);
|
|
|
|
|
auto c_shape = c_abstract->BuildShape();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_shape);
|
|
|
|
|
auto c_shape_info = c_shape->ToString();
|
|
|
|
|
MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
|
|
|
|
|
if (c_shape_info.find("-1") != string::npos) {
|
|
|
|
|
op_exec_info->is_dynamic_shape = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
|
|
|
|
|
size_t index) {
|
|
|
|
|
py::tuple cast_args(3);
|
|
|
|
@ -1326,6 +1349,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
|
|
|
|
|
op_exec_info->next_input_index};
|
|
|
|
|
VectorRef outputs;
|
|
|
|
|
session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
|
|
|
|
|
if (op_exec_info->is_dynamic_shape) {
|
|
|
|
|
op_exec_info->abstract = op_run_info.abstract;
|
|
|
|
|
}
|
|
|
|
|
auto result = BaseRefToPyData(outputs);
|
|
|
|
|
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
|
|
|
|
*status = PYNATIVE_SUCCESS;
|
|
|
|
|