add arglist in execute info

pull/7556/head
chenfei 4 years ago
parent 688c7c104b
commit 14e2afe2f5

@ -31,7 +31,10 @@ namespace pipeline {
struct ExecutorInfo {
FuncGraphPtr func_graph;
ResourcePtr resource;
// The num of input data.
std::size_t arg_list_size;
// The all args of graph,including input data and weight.
VectorRef arg_list;
};
using ExecutorInfoPtr = std::shared_ptr<ExecutorInfo>;

@ -181,13 +181,6 @@ FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) {
return info_[phase]->func_graph;
}
std::size_t ExecutorPy::ArgListSize(const std::string &phase) {
if (info_.count(phase) == 0) {
MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase);
}
return info_[phase]->arg_list_size;
}
compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) {
ResourcePtr res = GetResource(phase);
MS_EXCEPTION_IF_NULL(res);
@ -702,8 +695,9 @@ void Pipeline::Run() {
}
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
MS_EXCEPTION_IF_NULL(arg_list);
std::size_t size = args.size();
bool arg_list_inited = !arg_list->empty();
for (std::size_t i = 0; i < size; i++) {
py::object arg = args[i];
auto ms_context = MsContext::GetInstance();
@ -715,7 +709,14 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
if (!succ) {
MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed.";
}
arg_list->push_back(converted);
if (!arg_list_inited) {
arg_list->push_back(converted);
continue;
}
if (i >= arg_list->size()) {
MS_LOG(EXCEPTION) << "i:" << i << " output of range:" << arg_list->size();
}
(*arg_list)[i] = converted;
}
MS_EXCEPTION_IF_NULL(res);
@ -792,20 +793,23 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
return args;
}
#endif
std::size_t full_arg_size = ArgListSize(phase_s);
if (size > full_arg_size) {
MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << full_arg_size;
auto iter = info_.find(phase_s);
if (iter == info_.end()) {
MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase_s);
}
VectorRef arg_list;
ProcessVmArg(args, phase_s, &arg_list);
auto &execute_info = iter->second;
MS_EXCEPTION_IF_NULL(execute_info);
if (size > execute_info->arg_list_size) {
MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << execute_info->arg_list_size;
}
ProcessVmArg(args, phase_s, &execute_info->arg_list);
// Start to run phase.
compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s);
if (run == nullptr) {
MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s;
}
MS_LOG(DEBUG) << "Eval run" << backend;
BaseRef value = (*run)(arg_list);
BaseRef value = (*run)(execute_info->arg_list);
MS_LOG(DEBUG) << "Run end";
return BaseRefToPyData(value);
}

@ -82,7 +82,6 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
ResourcePtr GetResource(const std::string &phase);
FuncGraphPtr GetFuncGraph(const std::string &phase);
py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type);
std::size_t ArgListSize(const std::string &phase);
compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
bool HasCompiled(const std::string &phase) const;

Loading…
Cancel
Save