Fix the bug where an op that returns a single-element tuple is wrongly unwrapped in PyNative mode

pull/10672/head
buxue 5 years ago
parent 237faca57e
commit 2739b6d5e5

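The user-visible bug: RunOp used to hand every result back to Python as a py::tuple, and the Python wrapper unconditionally unwrapped 1-element tuples, so an op whose single output is legitimately a tuple lost its tuple wrapper. A minimal repro sketch, assuming a standard PyNative setup (Split with output_num=1 is an illustrative choice, not taken from this commit):

from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor([[1.0, 2.0], [3.0, 4.0]])
out = P.Split(axis=0, output_num=1)(x)
# Before this commit: the 1-element result tuple was unconditionally
# unwrapped, so out came back as a bare Tensor instead of (Tensor,).
# After it: the op's abstract is a sequence, so the tuple is preserved.
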
@@ -523,7 +523,7 @@ PynativeExecutor::~PynativeExecutor() {
   ClearRes();
 }
-py::tuple RunOp(const py::args &args) {
+py::object RunOp(const py::args &args) {
   auto executor = PynativeExecutor::GetInstance();
   MS_EXCEPTION_IF_NULL(executor);
   OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args);
@@ -555,10 +555,10 @@ py::tuple RunOp(const py::args &args) {
   }
 }
-py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
+py::object PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
   if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
-    return RunOpWithInitBackendPolicy(op_exec_info);
+    return RunOpWithInitBackendPolicy(op_exec_info)[0];
   }
   // make cnode for building grad graph if grad flag is set.
   abstract::AbstractBasePtrList args_spec_list;
@@ -574,14 +574,10 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   MS_EXCEPTION_IF_NULL(prim);
   py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
   if (!output["value"].is_none()) {
-    py::tuple value_ret(1);
-    value_ret[0] = output["value"];
-    return value_ret;
+    return output["value"];
   }
   if (prim->is_const_prim()) {
-    py::tuple value_ret(1);
-    value_ret[0] = "";
-    return value_ret;
+    return py::cast("");
   }
   // add output abstract info into cache
   if (!is_find && !op_exec_info->is_dynamic_shape) {
@@ -593,10 +589,12 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   }
   // run op with selected backend
   auto result = RunOpWithInitBackendPolicy(op_exec_info);
-  py::object out_real = result;
-  if (result.size() == 1) {
-    MS_LOG(DEBUG) << "Output size is 1";
+  py::object out_real;
+  if (result.size() == 1 && op_exec_info->abstract != nullptr &&
+      !op_exec_info->abstract->isa<abstract::AbstractSequeue>()) {
     out_real = result[0];
+  } else {
+    out_real = result;
   }
   // update output abstract for cnode
   if (cnode != nullptr) {
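This condition is the heart of the fix: the backend result is unwrapped only when it holds exactly one element, the op's abstract is known, and that abstract is not a sequence type (AbstractSequeue covers tuple/list outputs). A rough Python mirror of the rule, with hypothetical names, since the real check lives in the C++ above:

def _unwrap_single(result, abstract_is_sequence):
    """Illustrative mirror of the new RunOpInner unwrap rule."""
    if len(result) == 1 and not abstract_is_sequence:
        return result[0]  # single non-tuple output: return the bare value
    return result         # genuine tuple output: keep the tuple intact
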
@@ -609,7 +607,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   SaveAllResult(op_exec_info, cnode, out_real);
   // Update the abstract and device address of value node with tensor in grad graph
   UpdateAbstractAndDeviceAddress(op_exec_info, out_real);
-  return result;
+  return out_real;
 }
 OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
@@ -788,7 +786,7 @@ py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &typ
   op_exec->is_mixed_precision_cast = true;
   op_exec->next_op_name = op_name;
   op_exec->next_input_index = index;
-  return RunOpInner(op_exec)[0];
+  return RunOpInner(op_exec);
 }
 py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name,
@@ -1249,11 +1247,9 @@ py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_e
   auto backend_policy = InitEnv(op_exec_info);
   PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
   // returns a null py::tuple on error
-  py::tuple err_ret(0);
   py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
   if (status != PYNATIVE_SUCCESS) {
-    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
-    return err_ret;
+    MS_LOG(EXCEPTION) << "Failed to run " << op_exec_info->op_name;
   }
   MS_LOG(DEBUG) << "RunOp end";
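A failed backend run now throws (MS_LOG(EXCEPTION) raises rather than logging and returning) instead of handing an empty tuple back for Python to inspect; this is why the emptiness check in _run_op disappears in the last hunk below. A hedged sketch of the caller-visible difference, assuming the C++ exception surfaces in Python as a RuntimeError:

def run_op_checked(run_op, *args):
    # Old contract: run_op returned an empty tuple on failure.
    # New contract (sketch): run_op raises, so failures land here.
    try:
        return run_op(*args)
    except RuntimeError as err:
        print("op failed:", err)
        raise
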
@@ -1361,16 +1357,18 @@ py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynati
   auto primitive = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(primitive);
   auto result = primitive->RunPyComputeFunction(op_inputs);
+  MS_LOG(INFO) << "RunOpInVM end";
   if (py::isinstance<py::none>(result)) {
     MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
     *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
     py::tuple err_ret(0);
     return std::move(err_ret);
   }
-  // execute op
-  py::tuple tuple_result = py::make_tuple(result);
   *status = PYNATIVE_SUCCESS;
-  MS_LOG(INFO) << "RunOpInVM end";
+  if (py::isinstance<py::tuple>(result)) {
+    return result;
+  }
+  py::tuple tuple_result = py::make_tuple(result);
   return std::move(tuple_result);
 }
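RunOpInVM previously wrapped every VM result in py::make_tuple, so an op that already produced a tuple came back double-nested; tuples now pass through untouched. The packing rule, mirrored in Python with a hypothetical helper:

def _pack_vm_result(result):
    """Illustrative mirror of the new RunOpInVM packing rule."""
    if isinstance(result, tuple):
        return result   # already a tuple: forward as-is, no double nesting
    return (result,)    # single value: pack once for the caller
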

@@ -55,7 +55,7 @@ using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAb
 using OpIndexWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
 using TensorIdWithTensor = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
-py::tuple RunOp(const py::args &args);
+py::object RunOp(const py::args &args);
 void ClearPyNativeSession();
@@ -114,7 +114,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void EnterConstruct(const py::object &cell);
   void LeaveConstruct(const py::object &cell);
-  py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
+  py::object RunOpInner(const OpExecInfoPtr &op_exec_info);
   OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
   void NewGraph(const py::object &cell, const py::args &args);
   py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);

@@ -510,8 +510,4 @@ def constexpr(fn=None, get_instance=True, name=None):
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
     output = real_run_op(obj, op_name, args)
-    if not output:
-        raise RuntimeError("Pynative run op %s failed!" % op_name)
-    if len(output) == 1:
-        output = output[0]
     return output
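With unwrapping and failure handling both moved into C++ (RunOpInner returns the bare value for single non-sequence outputs; RunOpWithInitBackendPolicy throws on error), _run_op becomes a plain pass-through. A sketch of the post-fix expectations (op classes are illustrative):

from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

single = P.Square()(Tensor([2.0]))
assert isinstance(single, Tensor)  # bare value, no (Tensor,) wrapper
grouped = P.Split(axis=0, output_num=1)(Tensor([[1.0, 2.0]]))
assert isinstance(grouped, tuple) and len(grouped) == 1  # tuple preserved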
