From 14e2afe2f55bd218a95862b71d060b75be4041c7 Mon Sep 17 00:00:00 2001 From: chenfei Date: Wed, 21 Oct 2020 14:17:23 +0800 Subject: [PATCH] add arglist in execute info --- mindspore/ccsrc/pipeline/jit/base.h | 3 ++ mindspore/ccsrc/pipeline/jit/pipeline.cc | 38 +++++++++++++----------- mindspore/ccsrc/pipeline/jit/pipeline.h | 1 - 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/base.h b/mindspore/ccsrc/pipeline/jit/base.h index 33a410d661..41fbc05bcb 100644 --- a/mindspore/ccsrc/pipeline/jit/base.h +++ b/mindspore/ccsrc/pipeline/jit/base.h @@ -31,7 +31,10 @@ namespace pipeline { struct ExecutorInfo { FuncGraphPtr func_graph; ResourcePtr resource; + // The num of input data. std::size_t arg_list_size; + // The all args of graph,including input data and weight. + VectorRef arg_list; }; using ExecutorInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 818a8c3586..b8e3b75e91 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -181,13 +181,6 @@ FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { return info_[phase]->func_graph; } -std::size_t ExecutorPy::ArgListSize(const std::string &phase) { - if (info_.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - return info_[phase]->arg_list_size; -} - compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { ResourcePtr res = GetResource(phase); MS_EXCEPTION_IF_NULL(res); @@ -702,8 +695,9 @@ void Pipeline::Run() { } void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { + MS_EXCEPTION_IF_NULL(arg_list); std::size_t size = args.size(); - + bool arg_list_inited = !arg_list->empty(); for (std::size_t i = 0; i < size; i++) { py::object arg = args[i]; auto ms_context = MsContext::GetInstance(); @@ -715,7 +709,14 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef if (!succ) { MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; } - arg_list->push_back(converted); + if (!arg_list_inited) { + arg_list->push_back(converted); + continue; + } + if (i >= arg_list->size()) { + MS_LOG(EXCEPTION) << "i:" << i << " output of range:" << arg_list->size(); + } + (*arg_list)[i] = converted; } MS_EXCEPTION_IF_NULL(res); @@ -792,20 +793,23 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { return args; } #endif - std::size_t full_arg_size = ArgListSize(phase_s); - if (size > full_arg_size) { - MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << full_arg_size; + auto iter = info_.find(phase_s); + if (iter == info_.end()) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase_s); } - VectorRef arg_list; - ProcessVmArg(args, phase_s, &arg_list); - + auto &execute_info = iter->second; + MS_EXCEPTION_IF_NULL(execute_info); + if (size > execute_info->arg_list_size) { + MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << execute_info->arg_list_size; + } + ProcessVmArg(args, phase_s, &execute_info->arg_list); + // Start to run phase. compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s); if (run == nullptr) { MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s; } - MS_LOG(DEBUG) << "Eval run" << backend; - BaseRef value = (*run)(arg_list); + BaseRef value = (*run)(execute_info->arg_list); MS_LOG(DEBUG) << "Run end"; return BaseRefToPyData(value); } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index 53adefd0d8..3bc8281a60 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -82,7 +82,6 @@ class ExecutorPy : public std::enable_shared_from_this { ResourcePtr GetResource(const std::string &phase); FuncGraphPtr GetFuncGraph(const std::string &phase); py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); - std::size_t ArgListSize(const std::string &phase); compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); bool HasCompiled(const std::string &phase) const;