fix Profiling

pull/1353/head
chuxing 4 years ago
parent deebe05906
commit 795a935d34

@ -1072,30 +1072,39 @@ Status HybridModelBuilder::InitWeights() {
return SUCCESS;
}
Status HybridModelBuilder::LoadTask(NodeItem &node_item) {
auto &node_ptr = node_item.node;
GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
auto load_ret = node_item.node_executor->LoadTask(hybrid_model_,
node_ptr,
node_item.kernel_task);
if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
return load_ret;
}
GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
return SUCCESS;
}
Status HybridModelBuilder::LoadTasks() {
GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed.");
std::map<int64_t, NodeItem *> ordered_node_items;
std::map<int64_t, NodeItem *> ordered_partitioned_calls;
for (auto &it : hybrid_model_.node_items_) {
auto &node_item = it.second;
ordered_node_items.emplace(node_item->node_id, node_item.get());
}
for (auto &it : ordered_node_items) {
auto &node_item = it.second;
auto &node_ptr = node_item->node;
if (node_item->node_type == NETOUTPUT) {
continue;
}
GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
auto load_ret = node_item->node_executor->LoadTask(hybrid_model_,
node_ptr,
node_item->kernel_task);
if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
return load_ret;
if (node_item->node_type == PARTITIONEDCALL) {
ordered_partitioned_calls.emplace(node_item->node_id, node_item.get());
}
GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item));
}
GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
// HCCL operators need to be loaded in the same order across different processes
for (auto &it : ordered_partitioned_calls) {
GE_CHK_STATUS_RET_NOLOG(LoadTask(*it.second));
}
return SUCCESS;

@ -57,6 +57,7 @@ class HybridModelBuilder {
Status ValidateParams();
Status LoadGraph();
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
Status LoadTask(NodeItem &node_item);
Status LoadTasks();
Status IdentifyVariableOutputs(NodeItem &node_item);
Status IdentifySameInputs(NodeItem &node_item);

Loading…
Cancel
Save