You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/manager/graph_context.cc

108 lines
3.8 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/manager/graph_context.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
namespace ge {
/**
 * @brief Build a context bound to the compute graph held by @p graph_node.
 *
 * The constructor delegates to SetComputeGraph(). The constructor and the setter therefore
 * share one resolution path and the same error logging/reporting, instead of keeping two
 * copies of the same logic in sync.
 *
 * A constructor cannot return a Status, so failures are only logged/reported:
 *  - if @p graph_node is nullptr, compute_graph_ / current_graph_id_ are left untouched;
 *  - if the node has neither a compute graph nor a ge::Graph, current_graph_id_ is set
 *    from the node and compute_graph_ stays nullptr.
 *
 * @param graph_node graph node to bind; may be nullptr (an error is logged).
 */
GraphContext::GraphContext(const GraphNodePtr &graph_node) {
  // SetComputeGraph already logs and reports every failure; the status is intentionally discarded.
  (void)SetComputeGraph(graph_node);
}
/**
 * @brief (Re)bind this context to the graph owned by @p graph_node.
 *
 * The graph id is always copied from the node. The compute graph is taken directly from the
 * node when it has one. Otherwise it is derived from the node's ge::Graph via
 * GraphUtils::GetComputeGraph().
 *
 * NOTE(review): the value returned by GraphUtils::GetComputeGraph() is not checked here, so if
 * that helper can yield nullptr, this function still returns SUCCESS with a null compute_graph_.
 *
 * @param graph_node node supplying the graph id and the graph.
 * @return SUCCESS on success;
 *         GE_GRAPH_PARAM_NULLPTR if @p graph_node is nullptr (context left unchanged);
 *         GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL if the node has neither a compute graph nor a
 *         ge::Graph (compute_graph_ is left nullptr, current_graph_id_ is updated).
 */
Status GraphContext::SetComputeGraph(const GraphNodePtr &graph_node) {
  if (graph_node == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param graph_node is nullptr, check invalid");
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "graphNode is NULL!");
    return GE_GRAPH_PARAM_NULLPTR;
  }

  current_graph_id_ = graph_node->GetGraphId();
  compute_graph_ = graph_node->GetComputeGraph();
  if (compute_graph_ != nullptr) {
    // Fast path: the node already carries a compute graph.
    return SUCCESS;
  }

  // No compute graph on the node: fall back to converting its ge::Graph.
  const std::shared_ptr<const ge::Graph> fallback_graph = graph_node->GetGraph();
  if (fallback_graph == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param graph in graph_node is nullptr, check invalid");
    GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "compute_graph by graphNode is NULL!");
    return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL;
  }
  compute_graph_ = GraphUtils::GetComputeGraph(*fallback_graph);
  return SUCCESS;
}
/**
 * @brief Initialization hook for the graph context; currently does nothing.
 * @param options initialization options (not consulted by this implementation).
 * @return always SUCCESS.
 */
Status GraphContext::Initialize(const std::map<std::string, std::string> &options) const {
  (void)options;  // Not used yet; the cast only silences unused-parameter warnings.
  return SUCCESS;
}
/**
 * @brief Finalization hook for the graph context; currently does nothing.
 * @return always SUCCESS.
 */
Status GraphContext::Finalize() const {
  return SUCCESS;
}
/**
 * @brief Copy the tensor recorded for variable @p var_data_name into @p returned_tensor.
 *
 * Scans the variable/tensor table in order for the first record whose key's element 0 equals
 * @p var_data_name. It then copies that record's tensor descriptor and data into the output.
 *
 * @param var_data_name variable name to look up; must be non-empty.
 * @param returned_tensor output tensor; receives the descriptor and data of the match.
 * @return SUCCESS when a record is found and copied;
 *         GE_GRAPH_EMPTY_STRING_NAME if @p var_data_name is empty;
 *         GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE if the table has no entries;
 *         the error code from SetData if copying the data fails (the descriptor has already
 *         been written to @p returned_tensor at that point);
 *         GE_GRAPH_VARIABLE_DOES_NOT_EXIST if no record matches.
 */
Status GraphContext::GetVariableTensor(const std::string &var_data_name, GeTensor &returned_tensor) {
  if (var_data_name.empty()) {
    REPORT_INNER_ERROR("E19999", "Param var_data_name is empty, check invalid");
    GELOGE(GE_GRAPH_EMPTY_STRING_NAME, "Variable data name is empty!");
    return GE_GRAPH_EMPTY_STRING_NAME;
  }

  // auto&& binds to whatever the accessor yields (reference or temporary) without copying or
  // adding const, so the element access below matches a direct range-for over the accessor.
  auto &&var_table = GetVarNodeTensorTable();
  if (var_table.empty()) {
    REPORT_INNER_ERROR("E19999", "VarNodeTensorTable is empty, var_data_name:%s, check invalid",
                       var_data_name.c_str());
    GELOGE(GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE, "VarNodeTensorTable is empty!");
    return GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE;
  }

  for (auto &entry : var_table) {
    const bool name_matches = (var_data_name == std::get<0>(entry.first));
    if (!name_matches) {
      continue;
    }

    auto &stored_tensor = entry.second;
    returned_tensor.SetTensorDesc(stored_tensor.GetTensorDesc());
    const auto set_data_status = returned_tensor.SetData(stored_tensor.GetData());
    if (set_data_status == SUCCESS) {
      return SUCCESS;
    }
    REPORT_INNER_ERROR("E19999", "SetData to tensor fail, var_data_name:%s",
                       var_data_name.c_str());
    GELOGE(set_data_status, "Set Tensor data failed!");
    return set_data_status;
  }

  REPORT_INNER_ERROR("E19999", "VarRecord with data_name:%s does not exist, check invalid",
                     var_data_name.c_str());
  GELOGE(GE_GRAPH_VARIABLE_DOES_NOT_EXIST, "VarRecord with data_name %s does NOT exist!", var_data_name.c_str());
  return GE_GRAPH_VARIABLE_DOES_NOT_EXIST;
}
} // namespace ge