/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "subgraph_context.h"

#include "common/debug/log.h"
#include "hybrid/executor/hybrid_model_executor.h"

namespace ge {
namespace hybrid {
/// Stores the graph item and execution context pointers. Neither pointer is validated here;
/// graph_item_ is checked in Init().
SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
    : graph_item_(graph_item), execution_context_(execution_context) {
}

/// Sizes the flat input/output tensor tables from the graph item's totals.
/// @return SUCCESS once both tables are resized; GE_CHECK_NOTNULL handles a null graph_item_.
Status SubgraphContext::Init() {
  GE_CHECK_NOTNULL(graph_item_);
  GELOGD("[%s] Start to init subgraph context. total inputs = %d, total outputs = %d",
         graph_item_->GetName().c_str(),
         graph_item_->TotalInputs(),
         graph_item_->TotalOutputs());
  all_inputs_.resize(static_cast<size_t>(graph_item_->TotalInputs()));
  all_outputs_.resize(static_cast<size_t>(graph_item_->TotalOutputs()));

  return SUCCESS;
}

/// Returns the node state registered for node_item, creating it on first request.
/// The lookup and creation happen while mu_ is held.
/// @param node_item key of the state to fetch; must not be null.
/// @return the node state, or nullptr if node_item is null or the nothrow allocation failed.
NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
  // Reject null before touching the map: operator[] would insert an entry for the null key and
  // constructing NodeState would dereference it.
  if (node_item == nullptr) {
    GELOGE(INTERNAL_ERROR, "[Check][Param:node_item]node_item is nullptr.");
    return nullptr;
  }

  std::lock_guard<std::mutex> lk(mu_);
  auto &node_state = node_states_[node_item];
  if (node_state == nullptr) {
    node_state.reset(new(std::nothrow)NodeState(*node_item, this));
  }

  return node_state;
}

/// Stores tensor at the given position of the flat input table.
/// @param index position in all_inputs_; a negative value becomes a huge size_t and is rejected.
/// @return SUCCESS, or INTERNAL_ERROR when index is out of range.
Status SubgraphContext::SetInput(int index, const TensorValue &tensor) {
  if (static_cast<size_t>(index) >= all_inputs_.size()) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Param:index]input index out of range. all input num = %zu, input index = %d",
           all_inputs_.size(), index);
    REPORT_INNER_ERROR("E19999",
                       "input param index out of range when SubgraphContext %s, all input num = %zu, input index = %d.",
                       __FUNCTION__, all_inputs_.size(), index);
    return INTERNAL_ERROR;
  }

  all_inputs_[static_cast<size_t>(index)] = tensor;
  return SUCCESS;
}

/// Stores tensor as the input_index-th input of node_item by offsetting from node_item.input_start.
/// @return the result of SetInput(int, const TensorValue &) on the computed flat index.
Status SubgraphContext::SetInput(const NodeItem &node_item, int input_index, const TensorValue &tensor) {
  auto index = node_item.input_start + input_index;
  return SetInput(index, tensor);
}

/// Stores tensor as the output_index-th output of node_item in the flat output table.
/// @param output_index output slot relative to node_item.output_start; must lie in
///        [0, node_item.num_outputs).
/// @return SUCCESS, or INTERNAL_ERROR when the slot is outside the node's outputs or the table.
Status SubgraphContext::SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor) {
  auto index = node_item.output_start + output_index;
  // A negative output_index is rejected explicitly: it would pass the num_outputs comparison and
  // could otherwise resolve to a valid slot that belongs to a different node.
  if ((output_index < 0) || (output_index >= node_item.num_outputs) ||
      (static_cast<size_t>(index) >= all_outputs_.size())) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Param:output_index]output index out of range. all output num = %zu, node_item = %s,"
           "output index = %d.",
           all_outputs_.size(), node_item.DebugString().c_str(), output_index);
    REPORT_INNER_ERROR("E19999", "output index out of range when SubgraphContext %s. "
                       "all output num = %zu, node_item = %s, output index = %d.",
                       __FUNCTION__, all_outputs_.size(), node_item.DebugString().c_str(), output_index);
    return INTERNAL_ERROR;
  }

  all_outputs_[static_cast<size_t>(index)] = tensor;
  return SUCCESS;
}

/// Copies the tensor at the given position of the flat input table into tensor.
/// @param index position in all_inputs_.
/// @return SUCCESS on a valid index; INTERNAL_ERROR for a negative index. An index past the end
///         is handled by GE_CHECK_GE (presumably returns an error status - confirm against the
///         macro definition).
Status SubgraphContext::GetInput(int index, TensorValue &tensor) {
  // `index + 1U` converts a negative int to unsigned, so -1 would become 0 and slip through the
  // size check; reject negatives first and compare in size_t.
  if (index < 0) {
    GELOGE(INTERNAL_ERROR, "[Check][Param:index]input index must not be negative. input index = %d", index);
    return INTERNAL_ERROR;
  }

  const size_t input_index = static_cast<size_t>(index);
  GE_CHECK_GE(all_inputs_.size(), input_index + 1U);
  tensor = all_inputs_[input_index];
  return SUCCESS;
}

/// Appends the subgraph's output tensors to outputs.
/// For a dynamic graph the tensors are read from the input slots of the graph's output node
/// (nothing is appended when there is no output node); otherwise every entry of all_outputs_
/// is copied.
/// @return SUCCESS, or the status propagated from GetInput by GE_CHK_STATUS_RET_NOLOG.
Status SubgraphContext::GetOutputs(std::vector<TensorValue> &outputs) {
  if (graph_item_->IsDynamic()) {
    GELOGD("[%s] graph is dynamic, get outputs from net output input tensors", graph_item_->GetName().c_str());
    // get from net output inputs
    auto output_node = graph_item_->GetOutputNode();
    if (output_node != nullptr) {
      for (int i = 0; i < output_node->num_inputs; ++i) {
        TensorValue tensor;
        GE_CHK_STATUS_RET_NOLOG(GetInput(output_node->input_start + i, tensor));
        GELOGD("[%s] Adding output tensor by input index [%d], tensor = %s",
               graph_item_->GetName().c_str(),
               output_node->input_start + i,
               tensor.DebugString().c_str());
        outputs.emplace_back(std::move(tensor));
      }
    }
  } else {
    GELOGD("[%s] graph is non-dynamic, get outputs from subgraph outputs", graph_item_->GetName().c_str());
    // The number of appended tensors is known up front, so grow the vector once.
    outputs.reserve(outputs.size() + all_outputs_.size());
    for (auto &tensor : all_outputs_) {
      GELOGD("[%s] Adding output tensor: %s", graph_item_->GetName().c_str(), tensor.DebugString().c_str());
      outputs.emplace_back(tensor);
    }
  }

  return SUCCESS;
}

/// Waits on node through node_done_manager_.
/// @return SUCCESS when the manager's Await returns true; otherwise END_OF_SEQUENCE if the
///         execution context has is_eos_ set, else FAILED.
Status SubgraphContext::Await(const NodePtr &node) {
  if (node_done_manager_.Await(node)) {
    return SUCCESS;
  }

  if (execution_context_->is_eos_) {
    return END_OF_SEQUENCE;
  }

  return FAILED;
}

/// Logs and reports error unless it is END_OF_SEQUENCE, then calls Destroy() on
/// node_done_manager_ in every case.
void SubgraphContext::OnError(Status error) {
  if (error != END_OF_SEQUENCE) {
    GELOGE(error, "[Check][Param:error][%s] Error:%d occurred while executing graph.",
           graph_item_->GetName().c_str(), error);
    REPORT_INNER_ERROR("E19999", "[%s] Error:%d occurred while executing graph when SubgraphContext %s.",
                       graph_item_->GetName().c_str(), error, __FUNCTION__);
  }
  node_done_manager_.Destroy();
}

/// Forwards the completion of node to node_done_manager_.
void SubgraphContext::NodeDone(const NodePtr &node) {
  node_done_manager_.NodeDone(node);
}
}  // namespace hybrid
}  // namespace ge