|
|
|
@ -20,6 +20,7 @@
|
|
|
|
|
#include "hybrid/executor/hybrid_execution_context.h"
|
|
|
|
|
#include "hybrid/node_executor/aicore/aicore_task_builder.h"
|
|
|
|
|
#include "graph/load/new_model_manager/tbe_handle_store.h"
|
|
|
|
|
#include "graph/types.h"
|
|
|
|
|
|
|
|
|
|
using optiling::OpRunInfo;
|
|
|
|
|
|
|
|
|
@ -34,6 +35,23 @@ constexpr char const *kAttrAtomicOpParamSize = "atomic_op_para_size";
|
|
|
|
|
// Initializes this task from the op description and its task definition.
//
// Invokes InitWithTaskDef and InitTilingInfo through GE_CHK_STATUS_RET_NOLOG
// (presumably returning early when either one fails -- the macro is defined
// outside this file), then walks every output of op_desc and appends to
// output_indices_to_skip_ the index of each output whose tensor desc carries
// ATTR_NAME_MEMORY_SIZE_CALC_TYPE equal to MemorySizeCalcType::ALWAYS_EMPTY.
// Outputs with a null tensor desc are logged with a warning and ignored.
//
// Returns SUCCESS once the scan completes.
Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) {
  GE_CHK_STATUS_RET_NOLOG(InitWithTaskDef(op_desc, task_def));
  GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(op_desc));

  // Output indices are stored as int, so the output count has to fit in one.
  GE_CHECK_LE(op_desc.GetOutputsSize(), static_cast<size_t>(INT_MAX));
  const int num_outputs = static_cast<int>(op_desc.GetOutputsSize());

  // Attribute value that marks an output as always empty.
  const int32_t always_empty = static_cast<int32_t>(ge::MemorySizeCalcType::ALWAYS_EMPTY);

  for (int idx = 0; idx < num_outputs; ++idx) {
    const GeTensorDescPtr output_desc = op_desc.MutableOutputDesc(idx);
    if (output_desc != nullptr) {
      int32_t size_calc_type = 0;
      const bool has_calc_type =
          ge::AttrUtils::GetInt(output_desc, ATTR_NAME_MEMORY_SIZE_CALC_TYPE, size_calc_type);
      // Only an explicitly present attribute with the ALWAYS_EMPTY value counts.
      if (has_calc_type && (size_calc_type == always_empty)) {
        output_indices_to_skip_.push_back(idx);
      }
    } else {
      GELOGW("Op: %s, Index: %d, Tensor Desc is null", op_desc.GetName().c_str(), idx);
    }
  }

  return SUCCESS;
}
|
|
|
|
|
|
|
|
|
@ -221,7 +239,8 @@ Status AiCoreOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) {
|
|
|
|
|
size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + task_context.NumWorkspaces();
|
|
|
|
|
size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + task_context.NumWorkspaces()
|
|
|
|
|
- output_indices_to_skip_.size();
|
|
|
|
|
if (tiling_buffer_ != nullptr) {
|
|
|
|
|
++expected_arg_count;
|
|
|
|
|
}
|
|
|
|
@ -244,6 +263,11 @@ Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) {
|
|
|
|
|
for (int i = 0; i < task_context.NumOutputs(); ++i) {
|
|
|
|
|
const auto output = task_context.GetOutput(i);
|
|
|
|
|
GE_CHECK_NOTNULL(output);
|
|
|
|
|
if (find(output_indices_to_skip_.begin(), output_indices_to_skip_.end(), i) != output_indices_to_skip_.end()) {
|
|
|
|
|
GELOGD("Node:%s output[%d] is an optional, the address don't need to be saved.",
|
|
|
|
|
task_context.GetNodeName(), i);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
arg_base_[index++] = reinterpret_cast<uintptr_t>(output->GetData());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|