|
|
|
@ -323,12 +323,6 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->py_primitive = prim;
|
|
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
|
|
|
|
auto inst = PynativeExecutor::GetInstance();
|
|
|
|
|
if (inst->grad_flag()) {
|
|
|
|
|
op_exec_info->value = inst->GetForwardValue(op_exec_info);
|
|
|
|
|
} else {
|
|
|
|
|
(void)GetOpId(op_exec_info);
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->op_inputs = args[PY_INPUTS];
|
|
|
|
|
ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
|
|
|
|
|
return op_exec_info;
|
|
|
|
@ -903,6 +897,12 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
|
|
|
|
|
MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (PynativeExecutor::GetInstance()->grad_flag()) {
|
|
|
|
|
op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info);
|
|
|
|
|
} else {
|
|
|
|
|
(void)GetOpId(op_exec_info);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto result = RunOpInner(op_exec_info);
|
|
|
|
|
py::object out_real = result;
|
|
|
|
|
if (result.size() == 1) {
|
|
|
|
|