You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
4.3 KiB
128 lines
4.3 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "subgraph_context.h"
|
|
|
|
#include "common/debug/log.h"
|
|
|
|
namespace ge {
|
|
namespace hybrid {
|
|
SubgraphContext::SubgraphContext(const GraphItem *graph_item) : graph_item_(graph_item) {
|
|
|
|
}
|
|
|
|
Status SubgraphContext::Init() {
|
|
GE_CHECK_NOTNULL(graph_item_);
|
|
GELOGD("[%s] Start to init subgraph context. total inputs = %d, total outputs = %d",
|
|
graph_item_->GetName().c_str(),
|
|
graph_item_->TotalInputs(),
|
|
graph_item_->TotalOutputs());
|
|
all_inputs_.resize(static_cast<unsigned long>(graph_item_->TotalInputs()));
|
|
all_outputs_.resize(static_cast<unsigned long>(graph_item_->TotalOutputs()));
|
|
|
|
return SUCCESS;
|
|
}
|
|
|
|
NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
|
|
std::lock_guard<std::mutex> lk(mu_);
|
|
auto &node_state = node_states_[node_item];
|
|
if (node_state == nullptr) {
|
|
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
|
|
}
|
|
|
|
return node_state;
|
|
}
|
|
|
|
Status SubgraphContext::SetInput(int index, const TensorValue &tensor) {
|
|
if (static_cast<size_t>(index) >= all_inputs_.size()) {
|
|
GELOGE(INTERNAL_ERROR,
|
|
"output index output range. all input num = %zu, input index = %d",
|
|
all_inputs_.size(),
|
|
index);
|
|
return INTERNAL_ERROR;
|
|
}
|
|
all_inputs_[index] = tensor;
|
|
return SUCCESS;
|
|
}
|
|
|
|
Status SubgraphContext::SetInput(const NodeItem &node_item, int input_index, const TensorValue &tensor) {
|
|
auto index = node_item.input_start + input_index;
|
|
return SetInput(index, tensor);
|
|
}
|
|
|
|
Status SubgraphContext::SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor) {
|
|
auto index = node_item.output_start + output_index;
|
|
if ((output_index >= node_item.num_outputs) || (static_cast<size_t>(index) >= all_outputs_.size())) {
|
|
GELOGE(INTERNAL_ERROR,
|
|
"output index output range. all output num = %zu, node_item = %s, output index = %d",
|
|
all_outputs_.size(),
|
|
node_item.DebugString().c_str(),
|
|
output_index);
|
|
return INTERNAL_ERROR;
|
|
}
|
|
|
|
all_outputs_[index] = tensor;
|
|
return SUCCESS;
|
|
}
|
|
|
|
Status SubgraphContext::GetInput(int index, TensorValue &tensor) {
|
|
GE_CHECK_GE(all_inputs_.size(), index + 1U);
|
|
tensor = all_inputs_[index];
|
|
return SUCCESS;
|
|
}
|
|
|
|
Status SubgraphContext::GetOutputs(std::vector<TensorValue> &outputs) {
|
|
if (graph_item_->IsDynamic()) {
|
|
GELOGD("[%s] graph is dynamic, get outputs from net output input tensors", graph_item_->GetName().c_str());
|
|
// get from net output inputs
|
|
auto output_node = graph_item_->GetOutputNode();
|
|
if (output_node != nullptr) {
|
|
for (int i = 0; i < output_node->num_inputs; ++i) {
|
|
TensorValue tensor;
|
|
GE_CHK_STATUS_RET_NOLOG(GetInput(output_node->input_start + i, tensor));
|
|
GELOGD("[%s] Adding output tensor by input index [%d], tensor = %s",
|
|
graph_item_->GetName().c_str(),
|
|
output_node->input_start + i,
|
|
tensor.DebugString().c_str());
|
|
outputs.emplace_back(std::move(tensor));
|
|
}
|
|
}
|
|
} else {
|
|
GELOGD("[%s] graph is non-dynamic, get outputs from subgraph outputs", graph_item_->GetName().c_str());
|
|
for (auto &tensor : all_outputs_) {
|
|
GELOGD("[%s] Adding output tensor: %s", graph_item_->GetName().c_str(), tensor.DebugString().c_str());
|
|
outputs.emplace_back(tensor);
|
|
}
|
|
}
|
|
|
|
return SUCCESS;
|
|
}
|
|
|
|
bool SubgraphContext::Await(const NodePtr &node) {
|
|
return node_done_manager_.Await(node);
|
|
}
|
|
|
|
void SubgraphContext::OnError(Status error) {
|
|
GELOGE(error, "[%s] Error occurred while executing graph.", graph_item_->GetName().c_str());
|
|
node_done_manager_.Destroy();
|
|
}
|
|
|
|
void SubgraphContext::NodeDone(const NodePtr &node) {
|
|
node_done_manager_.NodeDone(node);
|
|
}
|
|
} // namespace hybrid
|
|
} // namespace ge
|