@@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   bool is_continue = false;
-  GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
-                    "[%s] Failed to execute iteration 0.",
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
                     task_context.GetNodeName());
   if (!is_continue) {
     for (int i = 0; i < task_context.NumInputs(); ++i) {
@@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   // backup original input tensor desc
-  std::vector<GeTensorDesc> ori_input_desc;
+  std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs());
   for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto tensor_desc = task_context.GetInputDesc(i);
-    GE_CHECK_NOTNULL(tensor_desc);
-    ori_input_desc.emplace_back(*tensor_desc);
+    GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i]));
   }
 
-  int iteration = 1;
-  while (true) {
+  int iteration = 0;
+  while (is_continue) {
+    ++iteration;
     GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration);
     GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
                       "[%s] Failed to execute iteration %d.",
                       task_context.GetNodeName(),
                       iteration);
-
-    if (!is_continue) {
-      GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
-      break;
-    }
-
-    ++iteration;
   }
 
-  for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto input_tensor = task_context.GetInput(i);
-    auto tensor_desc = task_context.MutableInputDesc(i);
-    GE_CHECK_NOTNULL(input_tensor);
-    GE_CHECK_NOTNULL(tensor_desc);
-    // restore original input tensor desc
-    *tensor_desc = std::move(ori_input_desc[i]);
-    GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor));
-  }
-
+  GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
   if (done_callback) {
     done_callback();
   }
 
+  for (int i = 0; i < task_context.NumInputs(); ++i) {
+    GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i]));
+  }
   return SUCCESS;
 }
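For orientation, the control flow that DoExecuteAsync ends up with after this change can be summarized by the minimal, self-contained sketch below: the condition is evaluated once before the loop, the body only runs while the condition stays true, and the state backed up before the loop is restored afterwards. This is an illustration only, assuming plain C++ stand-ins (RunWhile, State, the lambdas) rather than the GE TaskContext / GE_CHK_STATUS_RET API used in the diff.

// Illustrative sketch, not GE code: all names here are hypothetical stand-ins.
#include <functional>
#include <iostream>
#include <vector>

using State = std::vector<int>;  // stand-in for the loop-carried tensors

// While-style execution: cond is checked before the first body run, and the
// body's outputs become the next iteration's inputs.
State RunWhile(State inputs,
               const std::function<bool(const State &)> &cond,
               const std::function<State(const State &)> &body) {
  const State backup = inputs;       // mirrors the ori_input_desc backup
  bool is_continue = cond(inputs);   // cond evaluated once, before the loop
  int iteration = 0;
  while (is_continue) {
    ++iteration;
    inputs = body(inputs);           // one body execution
    is_continue = cond(inputs);      // re-check cond for the next pass
  }
  std::cout << "quit after " << iteration << " iterations\n";
  (void)backup;  // real code: UpdateInputDesc(i, ori_input_desc[i]) per input
  return inputs; // loop-carried values become the op's outputs
}

int main() {
  // Keep doubling until the value reaches at least 100: 1 -> 2 -> ... -> 128.
  const State out = RunWhile({1},
                             [](const State &s) { return s[0] < 100; },
                             [](const State &s) { return State{s[0] * 2}; });
  std::cout << "result = " << out[0] << "\n";
  return 0;
}
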
@@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) {
 }
 
 Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const {
-  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
-                    "[%s] Failed to execute cond-subgraph",
-                    task_context.GetNodeName());
-  if (!is_continue) {
-    return SUCCESS;
-  }
-
   GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName());
   GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr),
                     "[%s] Failed to execute cond-subgraph", task_context.GetNodeName());
@@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti
                     "[%s] Failed to move outputs to inputs",
                     task_context.GetNodeName());
 
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
+                    task_context.GetNodeName());
+
+  if (!is_continue) {
+    for (int i = 0; i < task_context.NumInputs(); ++i) {
+      auto input_desc = task_context.GetInput(i);
+      GE_CHECK_NOTNULL(input_desc);
+      GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc));
+    }
+  }
   return SUCCESS;
 }
 
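Similarly, a hedged sketch of what a single call to the reordered ExecuteOneLoop now does: run the body, move its outputs back into the input slots, re-evaluate the condition, and, if the loop is about to end, forward the current inputs to the op's outputs. StepOnce, Slots and StepResult are made-up names for illustration; the real implementation goes through ExecuteSubgraph, MoveOutputs2Inputs, ExecuteCond and SetOutput as shown above.

// Illustrative sketch, not GE code; compiles as a standalone translation unit.
#include <functional>
#include <utility>
#include <vector>

using Tensor = int;                 // stand-in for a GE tensor
using Slots = std::vector<Tensor>;  // stand-in for the input/output slots

struct StepResult {
  bool is_continue = false;         // result of the cond re-evaluation
};

// One iteration in the new order: body first, then cond.
StepResult StepOnce(Slots &inputs, Slots &outputs,
                    const std::function<Slots(const Slots &)> &body,
                    const std::function<bool(const Slots &)> &cond) {
  // 1. Execute the body subgraph on the current inputs.
  Slots body_outputs = body(inputs);

  // 2. Move the body's outputs back into the input slots for the next pass
  //    (MoveOutputs2Inputs in the diff).
  inputs = std::move(body_outputs);

  // 3. Re-evaluate the condition on the updated inputs (ExecuteCond).
  StepResult result;
  result.is_continue = cond(inputs);

  // 4. If the loop is ending, forward the current inputs to the op outputs
  //    (the SetOutput(i, *input_desc) loop added above).
  if (!result.is_continue) {
    outputs = inputs;
  }
  return result;
}

A driver such as the RunWhile sketch earlier would call StepOnce repeatedly until is_continue comes back false.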