You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/build/task_generator.h

171 lines
6.7 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_BUILD_TASK_GENERATOR_H_
#define GE_GRAPH_BUILD_TASK_GENERATOR_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>
#include "common/ge_inner_error_codes.h"
#include "common/opskernel/ops_kernel_info_types.h"
#include "framework/common/types.h"
#include "graph/compute_graph.h"
#include "graph/model.h"
#include "proto/task.pb.h"
#include "runtime/rt.h"
namespace ge {
class GELib;
class OpsKernelManager;
// Indices of the profiling points found in a graph.
// A default-constructed instance has both indices at 0 and no end indices.
// NOTE(review): "fp"/"bp" presumably mean forward/backward propagation points — confirm.
struct ProfilingPoint {
  uint32_t fp_index{0};
  uint32_t bp_index{0};
  // Ordered, de-duplicated collection of end-point indices.
  std::set<uint32_t> end_index{};
};
// Describes infos needed by generate task for fusion node
struct FusionTaskInfo {
RunContext &run_context;
ComputeGraphPtr &graph;
NodePtr &node;
OpDescPtr &fusion_op_desc;
uint32_t &node_index;
std::shared_ptr<GELib> &ge_lib;
const OpsKernelManager &ops_kernel_manager;
std::vector<domi::TaskDef> &task_def_list;
std::map<uint32_t, string> &op_name_map;
ProfilingPoint &profiling_point;
vector<uint32_t> all_reduce_nodes;
uint64_t all_reduce_node_idx;
};
class TaskGenerator {
 public:
  TaskGenerator() = default;
  TaskGenerator(const TaskGenerator &) = delete;
  TaskGenerator &operator=(const TaskGenerator &) = delete;
  virtual ~TaskGenerator();

  ///
  /// @param var_mem_base pointer stored for variable memory handling
  /// @param var_mem_size size associated with var_mem_base
  /// NOTE(review): presumably initializes var_mem_base_ / var_mem_size_ — definition not visible here.
  ///
  TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size);

  ///
  /// Get task info.
  /// @param model model
  /// @param graph compute graph
  /// @param session_id session id
  /// @param run_context run context
  /// @return SUCCESS: success
  ///         other: failed
  ///
  Status GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t session_id, RunContext &run_context);

  ///
  /// Find the profiling point indices of graph.
  /// @param graph compute graph
  /// @param profiling_point profiling point indices (non-const; presumably filled in)
  /// @param all_reduce_nodes all-reduce node indices (non-const; presumably filled in)
  /// @return SUCCESS: success
  ///         other: failed
  ///
  Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
                                std::vector<uint32_t> &all_reduce_nodes);

 private:
  Status UpdateAnchorStatus(const NodePtr &node);

  Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id);

  ///
  /// Call engine to generate known shape task.
  /// @param run_context run context
  /// @param graph compute graph
  /// @param task_def_list task def list generated by engine
  /// @param op_name_map relation of task index and op
  /// @return SUCCESS: success
  ///         other: failed
  ///
  Status GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, std::vector<domi::TaskDef> &task_def_list,
                      std::map<uint32_t, std::string> &op_name_map);

  ///
  /// Add model task to model.
  /// @param model_task_def model task
  /// @param session_id session id
  /// @param model_def model
  /// @param run_context run context
  /// @return SUCCESS: success
  ///         other: failed
  ///
  Status AddModelTaskToModel(const domi::ModelTaskDef &model_task_def, uint64_t session_id, Model &model_def,
                             RunContext &run_context);

  Status MarkNodeAndSetIndex(ComputeGraphPtr &graph);

  // Mark first and last op according to the same stream and engine
  Status MarkFirstAndLastOps(const std::vector<OpDescPtr> &ops, bool is_single_stream) const;

  // Profiling interface.
  // NOTE(review): the "Auto*" variants presumably locate fp/bp points automatically, while the
  // "*OfEnv" variants resolve them from the fp_point_str / bp_point_str strings — confirm in the .cc.
  Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
  Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
                           std::vector<uint32_t> &all_reduce_nodes) const;
  uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const;
  Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
                     ProfilingPoint &profiling_point) const;
  Status FindBpOfEnv(const ComputeGraphPtr &graph, const std::string &bp_point_str, ProfilingPoint &profiling_point,
                     std::vector<uint32_t> &all_reduce_nodes) const;
  Status GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
                      std::vector<uint32_t> &all_reduce_nodes, std::string &fp_point_str,
                      std::string &bp_point_str) const;
  Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
                                std::vector<uint32_t> &all_reduce_nodes) const;

  // Profiling-task insertion for the op at node_index; tasks are added to task_def_list.
  // The "Ar" variants take all_reduce_nodes and an is_insert_bp_profiling_task flag.
  Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
                                   std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
                                   std::vector<domi::TaskDef> &task_def_list);
  Status InsertProfilingArTaskBefore(const OpDescPtr &op_desc, std::vector<uint32_t> &all_reduce_nodes,
                                     uint32_t node_index, std::vector<domi::TaskDef> &task_def_list,
                                     bool is_insert_bp_profiling_task);
  Status InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
                                  std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
                                  std::vector<domi::TaskDef> &task_def_list);
  Status InsertProfilingArTaskAfter(const OpDescPtr &op_desc, std::vector<uint32_t> &all_reduce_nodes,
                                    uint32_t node_index, std::vector<domi::TaskDef> &task_def_list,
                                    bool is_insert_bp_profiling_task);

  static bool IsProfPoint(const OpDescPtr &op, const std::string &name);

  ///
  /// Call engine to generate task for fusion node.
  /// @param fusion_task_info information needed to generate tasks for the fusion node
  /// @param fusion_nodes nodes in graph with group_id attr, which means fusion node
  /// @param fusion_nodes_seen fusion nodes for which generate task has already been called
  /// @return SUCCESS: success
  ///         other: failed
  ///
  Status GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info,
                                   std::map<int64_t, std::vector<NodePtr>> &fusion_nodes,
                                   std::unordered_set<Node *> &fusion_nodes_seen);

  Status SaveFusionNodes(std::map<int64_t, std::vector<NodePtr>> &fusion_nodes, ComputeGraphPtr &graph);

  Status SetUnknownShapeStream(RunContext &run_context, rtStream_t &stream);
  Status DestroyUnknownShapeStream(RunContext &run_context, rtStream_t &stream);
  Status SetKnownShapeStream(RunContext &run_context, int64_t stream_id);

  bool IsSubGraphOfDynamicGraph(const ComputeGraphPtr &graph) const;

  // nullptr / 0 for a default-constructed generator.
  uint8_t *var_mem_base_ = nullptr;
  uint64_t var_mem_size_ = 0;
};
} // namespace ge
#endif // GE_GRAPH_BUILD_TASK_GENERATOR_H_