!1062 Fix bug of acl multi_task.

From: @zhao_zhixuan
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
pull/1062/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d991e8423a

@ -70,7 +70,6 @@ Status AiCoreTaskBuilder::BuildTask(std::unique_ptr<NodeTask> &node_task,
auto atomic_task = auto atomic_task =
std::unique_ptr<AtomicAddrCleanOpTask>(new(std::nothrow)AtomicAddrCleanOpTask()); std::unique_ptr<AtomicAddrCleanOpTask>(new(std::nothrow)AtomicAddrCleanOpTask());
GE_CHECK_NOTNULL(atomic_task); GE_CHECK_NOTNULL(atomic_task);
atomic_task->SetSingleOp(is_single_op);
GE_CHK_STATUS_RET(atomic_task->Init(*op_desc_, task_defs_.front()), GE_CHK_STATUS_RET(atomic_task->Init(*op_desc_, task_defs_.front()),
"[%s] Failed to init task for AtomicAddrClean", "[%s] Failed to init task for AtomicAddrClean",
op_desc_->GetName().c_str()); op_desc_->GetName().c_str());

@ -43,20 +43,21 @@ using std::vector;
namespace ge { namespace ge {
namespace { namespace {
const size_t kDataOutputNum = 1; const size_t kDataOutputNum = 1;
} // namespace
static Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { bool NeedHybridModel(GeModelPtr &ge_model) {
auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); auto tasks = ge_model->GetModelTaskDefPtr()->task();
for (const auto &node : comp_graph->GetAllNodes()) { int32_t kernel_task_num = 0;
auto op_desc = node->GetOpDesc(); for (int i = 0; i < tasks.size(); ++i) {
GE_CHECK_NOTNULL(op_desc); if (static_cast<rtModelTaskType_t>(tasks[i].type()) == RT_MODEL_TASK_KERNEL) {
const auto &depends = op_desc->GetOpInferDepends(); kernel_task_num++;
if (!depends.empty()) { if (kernel_task_num > 1) {
flag = true; return true;
return SUCCESS;
} }
} }
return SUCCESS;
} }
return false;
}
} // namespace
SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size) SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size)
: model_name_(model_name), ori_model_data_(model_data), ori_model_size_(model_size) {} : model_name_(model_name), ori_model_data_(model_data), ori_model_size_(model_size) {}
@ -497,9 +498,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &
auto ge_model = model_helper_.GetGeModel(); auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model); GE_CHECK_NOTNULL(ge_model);
bool infer_depend_flag = false; if (NeedHybridModel(ge_model)) {
GE_CHK_STATUS_RET_NOLOG(IfInferDepend(ge_model, infer_depend_flag));
if (ge_model->GetModelTaskDefPtr()->task_size() > 1 || infer_depend_flag) {
GELOGD("Build single op HybridModel."); GELOGD("Build single op HybridModel.");
GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized());
auto root_model = model_helper_.GetGeRootModel(); auto root_model = model_helper_.GetGeRootModel();

Loading…
Cancel
Save