|
|
@ -18,7 +18,6 @@
|
|
|
|
#define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
|
|
|
|
#define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
|
|
|
|
|
|
|
|
|
|
|
|
#include <stdint.h>
|
|
|
|
#include <stdint.h>
|
|
|
|
#include <functional>
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
#include <memory>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
#include <utility>
|
|
|
|
#include <utility>
|
|
|
@ -219,9 +218,9 @@ class LabelSwitchTaskInfo : public TaskInfo {
|
|
|
|
label_list_(label_list),
|
|
|
|
label_list_(label_list),
|
|
|
|
cond_(cond) {}
|
|
|
|
cond_(cond) {}
|
|
|
|
~LabelSwitchTaskInfo() override {}
|
|
|
|
~LabelSwitchTaskInfo() override {}
|
|
|
|
uint32_t label_size() { return label_size_; };
|
|
|
|
uint32_t label_size() const { return label_size_; }
|
|
|
|
const std::vector<uint32_t> &label_list() { return label_list_; };
|
|
|
|
const std::vector<uint32_t> &label_list() const { return label_list_; }
|
|
|
|
void *cond() { return cond_; };
|
|
|
|
void *cond() const { return cond_; }
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
uint32_t label_size_;
|
|
|
|
uint32_t label_size_;
|
|
|
@ -236,7 +235,7 @@ class EventTaskInfo : public TaskInfo {
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
|
|
|
|
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
|
|
|
|
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
|
|
|
|
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
|
|
|
|
virtual ~EventTaskInfo() override {}
|
|
|
|
~EventTaskInfo() override {}
|
|
|
|
|
|
|
|
|
|
|
|
uint32_t event_id_;
|
|
|
|
uint32_t event_id_;
|
|
|
|
};
|
|
|
|
};
|
|
|
@ -272,16 +271,13 @@ class FusionEndTaskInfo : public TaskInfo {
|
|
|
|
class HcclTaskInfo : public TaskInfo {
|
|
|
|
class HcclTaskInfo : public TaskInfo {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
|
|
|
|
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
|
|
|
|
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
|
|
|
|
void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num,
|
|
|
|
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
|
|
|
|
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
|
|
|
|
int64_t op_type, int64_t data_type, const std::string &group,
|
|
|
|
int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
|
|
|
|
std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
|
|
|
|
|
|
|
|
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
|
|
|
|
|
|
|
|
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
|
|
|
|
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
|
|
|
|
hccl_type_(hccl_type),
|
|
|
|
hccl_type_(hccl_type),
|
|
|
|
input_data_addr_(input_data_addr),
|
|
|
|
input_data_addr_(input_data_addr),
|
|
|
|
output_data_addr_(output_data_addr),
|
|
|
|
output_data_addr_(output_data_addr),
|
|
|
|
workspace_addr_(workspace_addr),
|
|
|
|
|
|
|
|
workspace_size_(workspace_size),
|
|
|
|
workspace_size_(workspace_size),
|
|
|
|
hccl_stream_num_(hccl_stream_num),
|
|
|
|
hccl_stream_num_(hccl_stream_num),
|
|
|
|
private_def_(private_def),
|
|
|
|
private_def_(private_def),
|
|
|
@ -290,16 +286,12 @@ class HcclTaskInfo : public TaskInfo {
|
|
|
|
root_id_(root_id),
|
|
|
|
root_id_(root_id),
|
|
|
|
op_type_(op_type),
|
|
|
|
op_type_(op_type),
|
|
|
|
data_type_(data_type),
|
|
|
|
data_type_(data_type),
|
|
|
|
group_(group),
|
|
|
|
group_(group) {}
|
|
|
|
hcom_bind_model_(hcom_bind_model),
|
|
|
|
|
|
|
|
hcom_unbind_model_(hcom_unbind_model),
|
|
|
|
|
|
|
|
hcom_distribute_task_(hcom_distribute_task) {}
|
|
|
|
|
|
|
|
~HcclTaskInfo() override {}
|
|
|
|
~HcclTaskInfo() override {}
|
|
|
|
|
|
|
|
|
|
|
|
const std::string &hccl_type() const { return hccl_type_; }
|
|
|
|
const std::string &hccl_type() const { return hccl_type_; }
|
|
|
|
void *input_data_addr() const { return input_data_addr_; }
|
|
|
|
void *input_data_addr() const { return input_data_addr_; }
|
|
|
|
void *output_data_addr() const { return output_data_addr_; }
|
|
|
|
void *output_data_addr() const { return output_data_addr_; }
|
|
|
|
void *workspace_addr() const { return workspace_addr_; }
|
|
|
|
|
|
|
|
int64_t workspace_size() const { return workspace_size_; }
|
|
|
|
int64_t workspace_size() const { return workspace_size_; }
|
|
|
|
int64_t hccl_stream_num() const { return hccl_stream_num_; }
|
|
|
|
int64_t hccl_stream_num() const { return hccl_stream_num_; }
|
|
|
|
const std::vector<uint8_t> &private_def() const { return private_def_; }
|
|
|
|
const std::vector<uint8_t> &private_def() const { return private_def_; }
|
|
|
@ -309,17 +301,11 @@ class HcclTaskInfo : public TaskInfo {
|
|
|
|
int64_t op_type() const { return op_type_; }
|
|
|
|
int64_t op_type() const { return op_type_; }
|
|
|
|
int64_t data_type() const { return data_type_; }
|
|
|
|
int64_t data_type() const { return data_type_; }
|
|
|
|
const std::string &group() const { return group_; }
|
|
|
|
const std::string &group() const { return group_; }
|
|
|
|
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
|
|
|
|
|
|
|
|
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
|
|
|
|
|
|
|
|
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
|
|
|
|
|
|
|
|
return hcom_distribute_task_;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
std::string hccl_type_;
|
|
|
|
std::string hccl_type_;
|
|
|
|
void *input_data_addr_;
|
|
|
|
void *input_data_addr_;
|
|
|
|
void *output_data_addr_;
|
|
|
|
void *output_data_addr_;
|
|
|
|
void *workspace_addr_;
|
|
|
|
|
|
|
|
int64_t workspace_size_;
|
|
|
|
int64_t workspace_size_;
|
|
|
|
int64_t hccl_stream_num_;
|
|
|
|
int64_t hccl_stream_num_;
|
|
|
|
std::vector<uint8_t> private_def_;
|
|
|
|
std::vector<uint8_t> private_def_;
|
|
|
@ -329,9 +315,6 @@ class HcclTaskInfo : public TaskInfo {
|
|
|
|
int64_t op_type_;
|
|
|
|
int64_t op_type_;
|
|
|
|
int64_t data_type_;
|
|
|
|
int64_t data_type_;
|
|
|
|
std::string group_;
|
|
|
|
std::string group_;
|
|
|
|
std::function<bool(void *, void *)> hcom_bind_model_;
|
|
|
|
|
|
|
|
std::function<bool(void *)> hcom_unbind_model_;
|
|
|
|
|
|
|
|
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ProfilerTraceTaskInfo : public TaskInfo {
|
|
|
|
class ProfilerTraceTaskInfo : public TaskInfo {
|
|
|
|