You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/hybrid/executor/subgraph_context.cc

145 lines
5.2 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "subgraph_context.h"
#include "common/debug/log.h"
#include "hybrid/executor/hybrid_model_executor.h"
namespace ge {
namespace hybrid {
SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
: graph_item_(graph_item), execution_context_(execution_context) {
}
/**
 * Validates the bound graph item and sizes the subgraph-wide input/output tensor tables.
 * @return SUCCESS on success; the status produced by GE_CHECK_NOTNULL when no graph item is bound;
 *         INTERNAL_ERROR when the graph item reports a negative input or output count.
 */
Status SubgraphContext::Init() {
  GE_CHECK_NOTNULL(graph_item_);
  const int total_inputs = graph_item_->TotalInputs();
  const int total_outputs = graph_item_->TotalOutputs();
  GELOGD("[%s] Start to init subgraph context. total inputs = %d, total outputs = %d",
         graph_item_->GetName().c_str(),
         total_inputs,
         total_outputs);
  // A negative count would wrap to an enormous size_t, making resize() throw or exhaust memory.
  if ((total_inputs < 0) || (total_outputs < 0)) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Param]invalid io num. total inputs = %d, total outputs = %d",
           total_inputs, total_outputs);
    REPORT_INNER_ERROR("E19999",
                       "invalid io num when SubgraphContext %s, total inputs = %d, total outputs = %d.",
                       __FUNCTION__, total_inputs, total_outputs);
    return INTERNAL_ERROR;
  }
  all_inputs_.resize(static_cast<size_t>(total_inputs));
  all_outputs_.resize(static_cast<size_t>(total_outputs));
  return SUCCESS;
}
/**
 * Returns the NodeState cached for node_item, creating it on first request.
 * The lookup/creation runs under mu_.
 * @param node_item node whose state is requested; must not be null.
 * @return the cached or newly created state, or nullptr if node_item is null or allocation fails.
 */
NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
  // Guard before touching the map: a null key would be inserted and *node_item would be UB.
  if (node_item == nullptr) {
    GELOGE(INTERNAL_ERROR, "[Check][Param:node_item]node_item is nullptr.");
    REPORT_INNER_ERROR("E19999", "node_item is nullptr when SubgraphContext %s.", __FUNCTION__);
    return nullptr;
  }
  std::lock_guard<std::mutex> lk(mu_);
  auto &node_state = node_states_[node_item];
  if (node_state == nullptr) {
    // nothrow: on failure the slot stays empty, so a later call retries the allocation.
    node_state.reset(new(std::nothrow)NodeState(*node_item, this));
    if (node_state == nullptr) {
      GELOGE(INTERNAL_ERROR, "[Create][NodeState]failed to create node state, node_item = %s.",
             node_item->DebugString().c_str());
      REPORT_INNER_ERROR("E19999", "failed to create node state when SubgraphContext %s, node_item = %s.",
                         __FUNCTION__, node_item->DebugString().c_str());
    }
  }
  return node_state;
}
/**
 * Stores tensor into slot `index` of the subgraph-wide input table.
 * Negative indices become huge values after the unsigned cast and are therefore rejected too.
 * @param index absolute slot in all_inputs_.
 * @param tensor value to store (copied).
 * @return SUCCESS, or INTERNAL_ERROR when index is outside the table.
 */
Status SubgraphContext::SetInput(int index, const TensorValue &tensor) {
  const size_t input_num = all_inputs_.size();
  if (static_cast<size_t>(index) < input_num) {
    all_inputs_[index] = tensor;
    return SUCCESS;
  }
  GELOGE(INTERNAL_ERROR,
         "[Check][Param:index]input index out of range. all input num = %zu, input index = %d",
         input_num, index);
  REPORT_INNER_ERROR("E19999",
                     "input param index out of range when SubgraphContext %s, all input num = %zu, input index = %d.",
                     __FUNCTION__, input_num, index);
  return INTERNAL_ERROR;
}
/**
 * Stores tensor as input input_index of node_item, i.e. at slot node_item.input_start + input_index.
 * @param node_item consuming node.
 * @param input_index input position relative to node_item.input_start; must be non-negative.
 * @param tensor value to store (copied by SetInput(int, ...)).
 * @return SUCCESS, or INTERNAL_ERROR if input_index is negative or the resulting slot is out of range.
 */
Status SubgraphContext::SetInput(const NodeItem &node_item, int input_index, const TensorValue &tensor) {
  // A negative input_index can land inside another node's slot range and still pass the total-size check.
  // NOTE(review): not bounded by node_item.num_inputs here; confirm callers never exceed it before tightening.
  if (input_index < 0) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Param:input_index]input index out of range. node_item = %s, input index = %d.",
           node_item.DebugString().c_str(), input_index);
    REPORT_INNER_ERROR("E19999", "input index out of range when SubgraphContext %s. "
                       "node_item = %s, input index = %d.",
                       __FUNCTION__, node_item.DebugString().c_str(), input_index);
    return INTERNAL_ERROR;
  }
  const auto index = node_item.input_start + input_index;
  return SetInput(index, tensor);
}
/**
 * Stores tensor as output output_index of node_item, i.e. at slot node_item.output_start + output_index.
 * @param node_item producing node.
 * @param output_index output position within node_item; must be in [0, node_item.num_outputs).
 * @param tensor value to store (copied).
 * @return SUCCESS, or INTERNAL_ERROR if output_index or the resulting slot is out of range.
 */
Status SubgraphContext::SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor) {
  auto index = node_item.output_start + output_index;
  // Reject negative indices explicitly: they can map onto another node's slot and pass the size check.
  if ((output_index < 0) || (output_index >= node_item.num_outputs) ||
      (static_cast<size_t>(index) >= all_outputs_.size())) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Param:output_index]output index out of range. all output num = %zu, node_item = %s, "
           "output index = %d.", all_outputs_.size(), node_item.DebugString().c_str(), output_index);
    REPORT_INNER_ERROR("E19999", "output index out of range when SubgraphContext %s. "
                       "all output num = %zu, node_item = %s, output index = %d.",
                       __FUNCTION__, all_outputs_.size(), node_item.DebugString().c_str(), output_index);
    return INTERNAL_ERROR;
  }
  all_outputs_[index] = tensor;
  return SUCCESS;
}
/**
 * Copies the tensor stored at slot `index` of the subgraph-wide input table into tensor.
 * @param index absolute slot in all_inputs_; must be in [0, all_inputs_.size()).
 * @param tensor receives a copy of the stored value.
 * @return SUCCESS, or the GE_CHECK_GE failure status when index is out of range.
 */
Status SubgraphContext::GetInput(int index, TensorValue &tensor) {
  // Check the sign first: for index == -1, `index + 1U` wraps to 0 and would slip past the size check.
  GE_CHECK_GE(index, 0);
  GE_CHECK_GE(all_inputs_.size(), static_cast<size_t>(index) + 1U);
  tensor = all_inputs_[index];
  return SUCCESS;
}
/**
 * Appends the subgraph's output tensors to outputs (existing elements are kept).
 * For a dynamic graph the tensors are the inputs recorded for the graph's output node;
 * if there is no output node, nothing is appended. For a non-dynamic graph, all_outputs_ is copied.
 * @param outputs destination vector; results are appended.
 * @return SUCCESS, or the status of a failing GetInput() call.
 */
Status SubgraphContext::GetOutputs(std::vector<TensorValue> &outputs) {
  const auto &graph_name = graph_item_->GetName();
  if (!graph_item_->IsDynamic()) {
    GELOGD("[%s] graph is non-dynamic, get outputs from subgraph outputs", graph_name.c_str());
    for (auto &output_tensor : all_outputs_) {
      GELOGD("[%s] Adding output tensor: %s", graph_name.c_str(), output_tensor.DebugString().c_str());
      outputs.emplace_back(output_tensor);
    }
    return SUCCESS;
  }

  GELOGD("[%s] graph is dynamic, get outputs from net output input tensors", graph_name.c_str());
  // Dynamic graphs expose their results as the inputs of the output node.
  const auto output_node = graph_item_->GetOutputNode();
  if (output_node == nullptr) {
    return SUCCESS;
  }
  for (int offset = 0; offset < output_node->num_inputs; ++offset) {
    const int input_index = output_node->input_start + offset;
    TensorValue output_tensor;
    GE_CHK_STATUS_RET_NOLOG(GetInput(input_index, output_tensor));
    GELOGD("[%s] Adding output tensor by input index [%d], tensor = %s",
           graph_name.c_str(),
           input_index,
           output_tensor.DebugString().c_str());
    outputs.emplace_back(std::move(output_tensor));
  }
  return SUCCESS;
}
/**
 * Waits on node_done_manager_ for node.
 * @param node node to wait for.
 * @return SUCCESS when the manager reports the node done; otherwise END_OF_SEQUENCE if the
 *         execution context has is_eos_ set, else FAILED.
 */
Status SubgraphContext::Await(const NodePtr &node) {
  const bool node_done = node_done_manager_.Await(node);
  if (!node_done) {
    return execution_context_->is_eos_ ? END_OF_SEQUENCE : FAILED;
  }
  return SUCCESS;
}
/**
 * Handles an execution error for this subgraph: logs and reports it (except END_OF_SEQUENCE),
 * then destroys the node-done manager in every case.
 * @param error status that terminated execution.
 */
void SubgraphContext::OnError(Status error) {
  // END_OF_SEQUENCE is not logged or reported as an error.
  const bool is_failure = (error != END_OF_SEQUENCE);
  if (is_failure) {
    const auto &graph_name = graph_item_->GetName();
    GELOGE(error, "[Check][Param:error][%s] Error:%d occurred while executing graph.",
           graph_name.c_str(), error);
    REPORT_INNER_ERROR("E19999", "[%s] Error:%d occurred while executing graph when SubgraphContext %s.",
                       graph_name.c_str(), error, __FUNCTION__);
  }
  node_done_manager_.Destroy();
}
void SubgraphContext::NodeDone(const NodePtr &node) {
node_done_manager_.NodeDone(node);
}
} // namespace hybrid
} // namespace ge