while loop failed to restore input desc

pull/1400/head
chuxing 4 years ago
parent f19cd2fca9
commit 4a7f623b12

@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
node_item.NodeName().c_str(), node_item.NodeName().c_str(),
this->num_pending_shapes_); this->num_pending_shapes_);
for (int i = 0; i < node_item.num_inputs; ++i){ input_tensor_desc.resize(node_item.num_inputs);
input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i)); for (int i = 0; i < node_item.num_inputs; ++i) {
node_item.GetInputDesc(i, input_tensor_desc[i]);
} }
for (int i = 0; i < node_item.num_outputs; ++i){ output_tensor_desc.resize(node_item.num_outputs);
output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i)); for (int i = 0; i < node_item.num_outputs; ++i) {
node_item.GetOutputDesc(i, output_tensor_desc[i]);
} }
} }

@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() {
} }
} }
GeTensorDescPtr NodeItem::MutableInputDesc(int index) const { GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
if (!has_optional_inputs) { if (!has_optional_inputs) {
return op_desc->MutableInputDesc(static_cast<uint32_t>(index)); return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
} }
@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
return op_desc->MutableInputDesc(input_desc_indices_[index]); return op_desc->MutableInputDesc(input_desc_indices_[index]);
} }
GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
std::lock_guard<std::mutex> lk(mu_);
return DoGetInputDesc(index);
}
Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
std::lock_guard<std::mutex> lk(mu_);
auto input_desc = DoGetInputDesc(index);
GE_CHECK_NOTNULL(input_desc);
tensor_desc = *input_desc;
return SUCCESS;
}
Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
std::lock_guard<std::mutex> lk(mu_);
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
GE_CHECK_NOTNULL(output_desc);
tensor_desc = *output_desc;
return SUCCESS;
}
GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
std::lock_guard<std::mutex> lk(mu_);
return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
}
Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
std::lock_guard<std::mutex> lk(mu_);
auto input_desc = DoGetInputDesc(index);
GE_CHECK_NOTNULL(input_desc);
*input_desc = tensor_desc;
return SUCCESS;
}
Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const { Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
if (!has_optional_inputs) { if (!has_optional_inputs) {
canonical_index = index; canonical_index = index;

@ -17,6 +17,7 @@
#ifndef GE_HYBRID_MODEL_NODE_ITEM_H_ #ifndef GE_HYBRID_MODEL_NODE_ITEM_H_
#define GE_HYBRID_MODEL_NODE_ITEM_H_ #define GE_HYBRID_MODEL_NODE_ITEM_H_
#include <mutex>
#include <vector> #include <vector>
#include "external/ge/ge_api_error_codes.h" #include "external/ge/ge_api_error_codes.h"
#include "graph/node.h" #include "graph/node.h"
@ -57,12 +58,16 @@ struct NodeItem {
bool IsInputShapeStatic(int index) const; bool IsInputShapeStatic(int index) const;
GeTensorDescPtr MutableOutputDesc(int index) const { GeTensorDescPtr MutableOutputDesc(int index) const;
return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
} Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
GeTensorDescPtr MutableInputDesc(int index) const; GeTensorDescPtr MutableInputDesc(int index) const;
Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const; Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const;
bool IsControlOp() const; bool IsControlOp() const;
@ -113,9 +118,11 @@ struct NodeItem {
Status ResolveDynamicState(); Status ResolveDynamicState();
Status ResolveStaticInputsAndOutputs(); Status ResolveStaticInputsAndOutputs();
void ResolveUnknownShapeType(); void ResolveUnknownShapeType();
GeTensorDescPtr DoGetInputDesc(int index) const;
std::vector<bool> is_input_shape_static_; std::vector<bool> is_input_shape_static_;
std::vector<uint32_t> input_desc_indices_; std::vector<uint32_t> input_desc_indices_;
mutable std::mutex mu_;
}; };
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
} }
bool is_continue = false; bool is_continue = false;
GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
"[%s] Failed to execute iteration 0.", "[%s] Failed to execute cond-subgraph",
task_context.GetNodeName()); task_context.GetNodeName());
if (!is_continue) { if (!is_continue) {
for (int i = 0; i < task_context.NumInputs(); ++i) { for (int i = 0; i < task_context.NumInputs(); ++i) {
@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
} }
// backup original input tensor desc // backup original input tensor desc
std::vector<GeTensorDesc> ori_input_desc; std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs());
for (int i = 0; i < task_context.NumInputs(); ++i) { for (int i = 0; i < task_context.NumInputs(); ++i) {
auto tensor_desc = task_context.GetInputDesc(i); GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i]));
GE_CHECK_NOTNULL(tensor_desc);
ori_input_desc.emplace_back(*tensor_desc);
} }
int iteration = 1; int iteration = 0;
while (true) { while (is_continue) {
++iteration;
GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration); GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration);
GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
"[%s] Failed to execute iteration %d.", "[%s] Failed to execute iteration %d.",
task_context.GetNodeName(), task_context.GetNodeName(),
iteration); iteration);
if (!is_continue) {
GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
break;
}
++iteration;
} }
GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
for (int i = 0; i < task_context.NumInputs(); ++i) {
auto input_tensor = task_context.GetInput(i);
auto tensor_desc = task_context.MutableInputDesc(i);
GE_CHECK_NOTNULL(input_tensor);
GE_CHECK_NOTNULL(tensor_desc);
// restore original input tensor desc
*tensor_desc = std::move(ori_input_desc[i]);
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor));
}
if (done_callback) { if (done_callback) {
done_callback(); done_callback();
} }
for (int i = 0; i < task_context.NumInputs(); ++i) {
GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i]));
}
return SUCCESS; return SUCCESS;
} }
@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) {
} }
Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const { Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const {
GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
"[%s] Failed to execute cond-subgraph",
task_context.GetNodeName());
if (!is_continue) {
return SUCCESS;
}
GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName()); GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName());
GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr), GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr),
"[%s] Failed to execute cond-subgraph", task_context.GetNodeName()); "[%s] Failed to execute cond-subgraph", task_context.GetNodeName());
@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti
"[%s] Failed to move outputs to inputs", "[%s] Failed to move outputs to inputs",
task_context.GetNodeName()); task_context.GetNodeName());
GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
"[%s] Failed to execute cond-subgraph",
task_context.GetNodeName());
if (!is_continue) {
for (int i = 0; i < task_context.NumInputs(); ++i) {
auto input_desc = task_context.GetInput(i);
GE_CHECK_NOTNULL(input_desc);
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc));
}
}
return SUCCESS; return SUCCESS;
} }

@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask {
Status ExecuteCond(TaskContext &task_context, bool &is_continue) const; Status ExecuteCond(TaskContext &task_context, bool &is_continue) const;
static Status MoveOutputs2Inputs(TaskContext &task_context); static Status MoveOutputs2Inputs(TaskContext &task_context);
Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const; Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const;
private: private:

@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const {
return node_state_; return node_state_;
} }
Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
return node_item_->GetInputDesc(index, tensor_desc);
}
Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
return const_cast<NodeItem *>(node_item_)->UpdateInputDesc(index, tensor_desc);
}
Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
return node_item_->GetOutputDesc(index, tensor_desc);
}
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

@ -50,9 +50,12 @@ class TaskContext {
const char *GetNodeName() const; const char *GetNodeName() const;
TensorValue *MutableInput(int index); TensorValue *MutableInput(int index);
ConstGeTensorDescPtr GetInputDesc(int index) const; ConstGeTensorDescPtr GetInputDesc(int index) const;
Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
ConstGeTensorDescPtr GetOutputDesc(int index) const; ConstGeTensorDescPtr GetOutputDesc(int index) const;
Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
GeTensorDescPtr MutableInputDesc(int index) const; GeTensorDescPtr MutableInputDesc(int index) const;
GeTensorDescPtr MutableOutputDesc(int index) const; GeTensorDescPtr MutableOutputDesc(int index) const;
Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
void ReleaseInputsAndOutputs(); void ReleaseInputsAndOutputs();
bool NeedCallback(); bool NeedCallback();
void ReleaseInput(int index); void ReleaseInput(int index);

Loading…
Cancel
Save