You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/metadef/graph/graph.cc

385 lines
13 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "external/graph/graph.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_op_types.h"
#include "graph/model.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
using std::map;
using std::pair;
using std::string;
using std::vector;
namespace ge {
// Implementation object held by ge::Graph through its impl_ pointer.
//
// Keeps the graph name, the operators registered with AddOp (keyed by operator name), the
// compute graph built by SetInputs, and the "node:index;node:index" description of the
// outputs selected by the last successful SetOutputs call. The class is non-copyable.
class GraphImpl {
 public:
  friend class GraphUtils;

  GraphImpl(const GraphImpl &) = delete;
  GraphImpl &operator=(const GraphImpl &) = delete;

  explicit GraphImpl(const std::string &name) : name_(name) {}

  // Breaks the connections of the compute graph's nodes (when a graph has been built) and of
  // every operator registered through AddOp.
  ~GraphImpl() {
    if (compute_graph_ != nullptr) {
      GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo());
    }
    for (const auto &it : op_list_) {
      // BreakConnect is called on a copy of the stored operator, as the map entry is const here.
      Operator op = it.second;
      op.BreakConnect();
    }
  }

  // Builds the compute graph from the given input operators.
  // Returns GRAPH_FAILED when inputs is empty or the graph cannot be created.
  graphStatus SetInputs(const std::vector<Operator> &inputs) {
    // Validate first so that an empty input list can never replace an already built graph.
    GE_CHK_BOOL_RET_STATUS(!inputs.empty(), GRAPH_FAILED, "set input NULL.");
    compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs);
    GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed.");
    compute_graph_->SetInputSize(static_cast<uint32_t>(inputs.size()));
    return GRAPH_SUCCESS;
  }

  // Selects every output of each given operator as a graph output.
  // Returns GRAPH_FAILED when no compute graph exists; an empty list is a warned no-op.
  graphStatus SetOutputs(const std::vector<Operator> &outputs) {
    if (compute_graph_ == nullptr) {
      GELOGE(GRAPH_FAILED, "set ComputeGraph failed.");
      return GRAPH_FAILED;
    }
    if (outputs.empty()) {
      GELOGW("set outputs size is 0.");
      return GRAPH_SUCCESS;
    }
    // An empty index list means "all outputs of this operator" in the overload below.
    std::vector<std::pair<Operator, std::vector<size_t>>> output_indexs;
    output_indexs.reserve(outputs.size());
    for (const auto &output : outputs) {
      output_indexs.emplace_back(output, std::vector<size_t>{});
    }
    return SetOutputs(output_indexs);
  }

  // Selects graph outputs as (operator, output indexes) pairs; an empty index list selects all
  // outputs of that operator. Operators missing from the graph and out-of-range indexes are
  // skipped with a warning.
  // Returns GRAPH_FAILED when no compute graph exists; an empty list is a warned no-op.
  graphStatus SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) {
    if (compute_graph_ == nullptr) {
      GELOGE(GRAPH_FAILED, "set ComputeGraph failed.");
      return GRAPH_FAILED;
    }
    if (output_indexs.empty()) {
      GELOGW("set outputs size is 0.");
      return GRAPH_SUCCESS;
    }
    // The description is assembled locally and stored once, so a repeated call replaces the
    // previous description instead of appending to an already trimmed string.
    std::string output_name;
    std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes;
    for (const auto &item : output_indexs) {
      const Operator &output = item.first;
      const vector<size_t> &indexs = item.second;
      ge::NodePtr node = compute_graph_->FindNode(output.GetName());
      if (node == nullptr) {
        GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str());
        continue;
      }
      ge::OpDescPtr tmp_op_ptr = node->GetOpDesc();
      GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue);
      const size_t out_size = tmp_op_ptr->GetOutputsSize();
      if (indexs.empty()) {
        for (size_t i = 0; i < out_size; ++i) {
          output_name += output.GetName() + ":" + std::to_string(i) + ";";
          output_nodes.emplace_back(node, i);
        }
      } else {
        for (const size_t index : indexs) {
          if (index >= out_size) {
            GELOGW("index[%zu] is not belong to out_node[%s]", index, output.GetName().c_str());
            continue;
          }
          // Record the selected output index itself (not its position in the request list), so
          // the description matches the entry pushed into output_nodes.
          output_name += output.GetName() + ":" + std::to_string(index) + ";";
          output_nodes.emplace_back(node, index);
        }
      }
    }
    // Del last ";"
    if (!output_name.empty()) {
      output_name.pop_back();
    }
    output_name_ = output_name;
    compute_graph_->SetUserDefOutput(output_name_);
    compute_graph_->SetOutputSize(static_cast<uint32_t>(output_indexs.size()));
    compute_graph_->SetGraphOutNodesInfo(output_nodes);
    return GRAPH_SUCCESS;
  }

  // Selects graph outputs as (operator, output name) pairs; an empty output name selects all
  // outputs of that operator.
  // Returns GRAPH_FAILED when no compute graph exists, when an operator is not found in the
  // graph, or when an output name is unknown; in those cases nothing is stored.
  // NOTE(review): unlike the index based overload this one does not call SetUserDefOutput -
  // confirm whether that is intentional.
  graphStatus SetOutputs(const std::vector<pair<Operator, string>> &outputs) {
    GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild.");
    GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0.");
    // Assembled locally so that a failure below does not leave a partial description behind and
    // a repeated call replaces the previous description.
    std::string output_name;
    std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes;
    for (const auto &item : outputs) {
      ge::NodePtr node = compute_graph_->FindNode(item.first.GetName());
      if (node == nullptr) {
        GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!",
               item.first.GetName().c_str());
        return GRAPH_FAILED;
      }
      ge::OpDescPtr tmp_op_ptr = node->GetOpDesc();
      GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue);
      const size_t out_size = tmp_op_ptr->GetOutputsSize();
      if (item.second.empty()) {
        for (size_t i = 0; i < out_size; ++i) {
          output_name += item.first.GetName() + ":" + std::to_string(i) + ";";
          output_nodes.emplace_back(node, i);
        }
      } else {
        const int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second);
        if (index < 0) {
          GELOGE(GRAPH_FAILED,
                 " Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!",
                 item.first.GetName().c_str(), item.second.c_str());
          return GRAPH_FAILED;
        }
        output_name += item.first.GetName() + ":" + std::to_string(index) + ";";
        output_nodes.emplace_back(node, index);
      }
    }
    // Del last ";"
    if (!output_name.empty()) {
      output_name.pop_back();
    }
    output_name_ = output_name;
    compute_graph_->SetOutputSize(static_cast<uint32_t>(outputs.size()));
    compute_graph_->SetGraphOutNodesInfo(output_nodes);
    GELOGI("********************SetOutputs Success***********************");
    GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str()));
    return GRAPH_SUCCESS;
  }

  // Registers the nodes matching the given operators as graph targets; operators that are not
  // found in the graph are skipped with a warning.
  // Returns GRAPH_FAILED when no compute graph exists; an empty list is a logged no-op.
  graphStatus SetTargets(const std::vector<Operator> &targets) {
    GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild.");
    GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0.");
    std::vector<ge::NodePtr> target_nodes;
    for (const auto &item : targets) {
      ge::NodePtr node = compute_graph_->FindNode(item.GetName());
      if (node == nullptr) {
        GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!",
               item.GetName().c_str());
        continue;
      }
      target_nodes.push_back(node);
    }
    compute_graph_->SetGraphTargetNodesInfo(target_nodes);
    return GRAPH_SUCCESS;
  }

  // True once a compute graph has been attached.
  bool IsValid() const { return (compute_graph_ != nullptr); }

  // Registers op under its name. Returns GRAPH_FAILED if that name is already registered.
  graphStatus AddOp(const ge::Operator &op) {
    const auto ret = op_list_.emplace(op.GetName(), op);
    GE_CHK_BOOL_RET_STATUS(ret.second, GRAPH_FAILED, "the op have added before, op name:%s.",
                           op.GetName().c_str());
    return GRAPH_SUCCESS;
  }

  // Appends the name of every registered operator to op_name (existing entries are kept).
  graphStatus GetAllOpName(std::vector<string> &op_name) const {
    for (const auto &it : op_list_) {
      op_name.push_back(it.second.GetName());
    }
    return GRAPH_SUCCESS;
  }

  // Copies the operator registered under name into op.
  // Returns GRAPH_FAILED (op untouched) when no such operator is registered.
  graphStatus FindOpByName(const string &name, ge::Operator &op) const {
    auto it = op_list_.find(name);
    GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str());
    op = it->second;
    return GRAPH_SUCCESS;
  }

  // Appends to ops every registered operator whose type equals type. Operators of type
  // ge::FRAMEWORKOP are matched on their ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE attribute instead.
  graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const {
    for (const auto &op : op_list_) {
      auto op_type = op.second.GetOpType();
      if (op_type == type) {
        ops.push_back(op.second);
        continue;
      }
      if (op_type == ge::FRAMEWORKOP) {
        // op_type is overwritten with the original framework type when the attribute is present.
        op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type);
        if (op_type == type) {
          ops.push_back(op.second);
        }
      }
    }
    return GRAPH_SUCCESS;
  }

  // Forwards the flag to the compute graph; only logs an error when no graph exists yet.
  void SetNeedIteration(bool need_iteration) {
    if (compute_graph_ == nullptr) {
      GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null.");
      return;
    }
    compute_graph_->SetNeedIteration(need_iteration);
  }

  const std::string &GetName() const { return name_; }

 private:
  std::string name_;
  // "node:index;node:index" description of the outputs chosen by the last successful SetOutputs.
  std::string output_name_;
  // Operators registered through AddOp, keyed by operator name.
  std::map<string, ge::Operator> op_list_;
  ComputeGraphPtr compute_graph_{nullptr};
};
// Creates the implementation object for a graph called `name`. When ComGraphMakeShared
// yields nullptr only a warning is logged and impl_ stays null.
Graph::Graph(const std::string &name) : impl_(ComGraphMakeShared<GraphImpl>(name)) {
  if (!(impl_ != nullptr)) {
    GELOGW("GraphImpl make shared failed, impl_ is nullptr");
  }
}
// Registers `op` with the graph implementation.
// Returns GRAPH_FAILED when impl_ is null, otherwise the implementation's status.
graphStatus Graph::AddOp(const ge::Operator &op) {
  GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr.");
  const graphStatus status = impl_->AddOp(op);
  return status;
}
// Appends the names of all registered operators to `op_name`.
// Returns GRAPH_FAILED when impl_ is null, otherwise the implementation's status.
graphStatus Graph::GetAllOpName(std::vector<string> &op_name) const {
  GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
                   "GetAllOpName failed: graph can not be used, impl is nullptr.");
  const graphStatus status = impl_->GetAllOpName(op_name);
  return status;
}
// Looks up the operator registered under `name` and stores it in `op`.
// `op` is first reset to an operator constructed from "NULL", so it holds that placeholder
// whenever the lookup fails. Returns GRAPH_FAILED when impl_ is null.
graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const {
  const Operator placeholder("NULL");
  op = placeholder;
  GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
                   "FindOpByName failed: graph can not be used, impl is nullptr.");
  const graphStatus status = impl_->FindOpByName(name, op);
  return status;
}
// Appends every registered operator of the given type to `ops`.
// GE_CHECK_NOTNULL returns early from this function when impl_ is null.
graphStatus Graph::FindOpByType(const string &type, std::vector<ge::Operator> &ops) const {
  GE_CHECK_NOTNULL(impl_);
  const graphStatus status = impl_->FindOpByType(type, ops);
  return status;
}
// Builds the graph from the given input operators and returns *this for chaining.
// Nothing is done when impl_ is null or `inputs` is empty; the implementation's status is
// deliberately discarded, so callers check the outcome through IsValid().
Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) {
  GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.");
  GE_CHK_BOOL_EXEC(!inputs.empty(), return *this, "SetInputs failed: input operator size can not be 0.");
  const graphStatus ignored_status = impl_->SetInputs(inputs);
  (void)ignored_status;
  return *this;
}
// Selects the given operators as graph outputs and returns *this for chaining.
// Logs an error and does nothing when impl_ is null; the implementation's status is discarded.
Graph &Graph::SetOutputs(const vector<ge::Operator> &outputs) {
  if (impl_ != nullptr) {
    (void)impl_->SetOutputs(outputs);
  } else {
    GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr.");
  }
  return *this;
}
// Selects graph outputs given as (operator, output indexes) pairs and returns *this for
// chaining. Logs an error and does nothing when impl_ is null; the implementation's status is
// discarded.
Graph &Graph::SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) {
  if (impl_ != nullptr) {
    (void)impl_->SetOutputs(output_indexs);
  } else {
    GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr.");
  }
  return *this;
}
// Selects graph outputs given as (operator, output name) pairs and returns *this for chaining.
// Does nothing when impl_ is null; the implementation's status is discarded.
Graph &Graph::SetOutputs(const std::vector<pair<Operator, string>> &outputs) {
  GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.");
  const graphStatus ignored_status = impl_->SetOutputs(outputs);
  (void)ignored_status;
  return *this;
}
// Registers the given operators as graph targets and returns *this for chaining.
// Logs an error and does nothing when impl_ is null; the implementation's status is discarded.
Graph &Graph::SetTargets(const vector<ge::Operator> &targets) {
  if (impl_ != nullptr) {
    (void)impl_->SetTargets(targets);
  } else {
    GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr.");
  }
  return *this;
}
// A graph is valid when it has an implementation object and that object reports itself valid.
bool Graph::IsValid() const { return (impl_ != nullptr) && impl_->IsValid(); }
// Forwards the need-iteration flag to the implementation; only logs an error when impl_ is null.
void Graph::SetNeedIteration(bool need_iteration) {
  if (impl_ != nullptr) {
    impl_->SetNeedIteration(need_iteration);
    return;
  }
  GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null.");
}
// Returns the compute graph held by `graph`, or nullptr when the graph is not valid.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) {
  // Graph::IsValid() is false when graph.impl_ is null, so the dereference below is guarded.
  const bool graph_usable = graph.IsValid();
  GE_CHK_BOOL_EXEC_NOLOG(graph_usable, return nullptr);
  return graph.impl_->compute_graph_;
}
// Wraps this graph in a temporary Model and saves that model to `file_name`.
// Returns the status reported by Model::SaveToFile.
graphStatus Graph::SaveToFile(const string &file_name) const {
  Model holder{};
  holder.SetGraph(*this);
  const graphStatus status = holder.SaveToFile(file_name);
  return status;
}
// Loads a Model from `file_name` and replaces this graph with the model's graph.
// On failure the load status is returned and this graph is left unchanged.
graphStatus Graph::LoadFromFile(const string &file_name) {
  Model holder{};
  const graphStatus status = holder.LoadFromFile(file_name);
  if (status == GRAPH_SUCCESS) {
    *this = holder.GetGraph();
  }
  return status;
}
// Returns the name the graph was created with.
// The constructor tolerates a null impl_, so that case is guarded here like in every other
// member: an error is logged and a reference to a static empty string is returned instead of
// dereferencing a null pointer.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const {
  // Static storage is required because the function returns a reference.
  static const std::string kEmptyName;
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetName failed: graph can not be used, impl is nullptr.");
    return kEmptyName;
  }
  return impl_->GetName();
}
// Wraps an existing compute graph in a Graph that carries the compute graph's name.
// A null `compute_graph` yields a Graph named ""; if the new Graph has no implementation
// object it is returned without the compute graph attached.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph
GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) {
  GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph(""));
  auto graph_name = compute_graph->GetName();
  Graph wrapper(graph_name);
  GE_CHK_BOOL_EXEC_NOLOG(wrapper.impl_ != nullptr, return wrapper);
  wrapper.impl_->compute_graph_ = compute_graph;
  return wrapper;
}
// Rebuilds the graph's operator table from the direct nodes of its compute graph: the table
// is cleared and one operator per node is stored under the node's name.
// GE_CHECK_NOTNULL returns early when the graph has no implementation or no compute graph.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) {
  GE_CHECK_NOTNULL(graph.impl_);
  GE_CHECK_NOTNULL(graph.impl_->compute_graph_);
  auto &op_list = graph.impl_->op_list_;
  op_list.clear();
  for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) {
    op_list[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node);
  }
  // The function returns graphStatus, so report success with the graphStatus constant used by
  // the rest of this file.
  return GRAPH_SUCCESS;
}
} // namespace ge