|
|
|
@ -151,6 +151,9 @@ Status HybridModelBuilder::Build() {
|
|
|
|
|
GE_CHK_STATUS_RET(InitConstantOps(), "[Invoke][InitConstantOps] failed, model_name_:[%s]", GetGraphName());
|
|
|
|
|
GE_CHK_STATUS_RET(InitVariableTensors(), "[Invoke][InitVariableTensors], model_name_:[%s]", GetGraphName());
|
|
|
|
|
GE_CHK_STATUS_RET(LoadTasks(), "[Invoke][LoadTasks] failed, model_name_:[%s]", GetGraphName());
|
|
|
|
|
GE_CHK_STATUS_RET(OptimizeDependenciesForConstantInputs(),
|
|
|
|
|
"[Invoke][OptimizeDependenciesForConstantInputs] failed, model_name_:[%s]",
|
|
|
|
|
GetGraphName());
|
|
|
|
|
GELOGI("[%s] Done building hybrid model successfully.", GetGraphName());
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
@ -353,6 +356,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
|
|
|
|
|
auto src_node_item = MutableNodeItem(src_node);
|
|
|
|
|
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
|
|
|
|
|
dependent_for_shape_inference.emplace(src_node);
|
|
|
|
|
host_input_value_dependencies_[&node_item].emplace_back(peer_out_anchor->GetIdx(), src_node_item);
|
|
|
|
|
GELOGD("[%s] Dependent added from output of [%s:%d]",
|
|
|
|
|
node_item.NodeName().c_str(),
|
|
|
|
|
src_node_item->NodeName().c_str(),
|
|
|
|
@ -1536,7 +1540,7 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) {
|
|
|
|
|
src_node->GetName().c_str(),
|
|
|
|
|
src_op_type.c_str());
|
|
|
|
|
|
|
|
|
|
if (src_op_type != CONSTANTOP && src_op_type != VARIABLE) {
|
|
|
|
|
if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1545,6 +1549,9 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) {
|
|
|
|
|
GELOGD("Got parent output index = %u", parent_index);
|
|
|
|
|
GE_CHECK_LE(parent_index, INT32_MAX);
|
|
|
|
|
node_item.ref_outputs.emplace(static_cast<int>(parent_index), src_node);
|
|
|
|
|
if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) {
|
|
|
|
|
known_subgraph_constant_output_refs_[&node_item].emplace(parent_index, src_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Data nodes marked with REF_VAR_SRC_VAR_NAME
|
|
|
|
@ -2176,5 +2183,88 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() {
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Walks host_input_value_dependencies_ and, for every dependency whose source output is a
// constant, copies that constant to a host tensor via Convert2HostTensor and drops the
// output from the source's to_const_output_id_list.
//
// A source output is treated as constant when either
//   - the source node item's node_type is CONSTANT or CONSTANTOP, or
//   - known_subgraph_constant_output_refs_ holds a non-null node for that output index
//     of the source node item.
//
// When at least one dependency of a node item was converted, its
// dependents_for_shape_inference is replaced by the nodes of the sources that still have
// at least one unconverted dependency.
//
// Returns SUCCESS when the whole table was processed; a failing Convert2HostTensor is
// handled by GE_CHK_STATUS_RET.
Status HybridModelBuilder::OptimizeDependenciesForConstantInputs() {
  // Constant node -> output indices for which Convert2HostTensor has already been called.
  std::map<NodePtr, std::set<uint32_t>> converted;
  for (auto &it : host_input_value_dependencies_) {
    auto node_item = it.first;
    // Per source node item: number of dependencies that were NOT converted.
    std::map<NodeItem *, int> ref_counts;
    bool changed = false;
    for (const auto &output_idx_and_node : it.second) {
      const auto output_idx = output_idx_and_node.first;
      const auto src_node_item = output_idx_and_node.second;
      ++ref_counts[src_node_item];
      NodePtr constant_node;
      if (src_node_item->node_type == CONSTANT || src_node_item->node_type == CONSTANTOP) {
        constant_node = src_node_item->node;
        GELOGD("src node [%s] is a constant", src_node_item->NodeName().c_str());
      } else {
        const auto iter = known_subgraph_constant_output_refs_.find(src_node_item);
        if (iter != known_subgraph_constant_output_refs_.end()) {
          // Use find() rather than operator[]: a pure lookup must not insert a null
          // entry into the member map for outputs that are not constant.
          const auto const_ref = iter->second.find(output_idx);
          if (const_ref != iter->second.end()) {
            constant_node = const_ref->second;
          }
          if (constant_node != nullptr) {
            GELOGD("Output[%u] of subgraph [%s] is a constant", output_idx, src_node_item->NodeName().c_str());
          }
        }
      }

      if (constant_node == nullptr) {
        GELOGD("Output[%u] of [%s] is not a constant", output_idx, src_node_item->NodeName().c_str());
        continue;
      }

      // Convert each (constant node, output index) pair at most once.
      auto &converted_outputs = converted[constant_node];
      if (converted_outputs.count(output_idx) == 0) {
        GE_CHK_STATUS_RET(Convert2HostTensor(constant_node, src_node_item->node_id, output_idx),
                          "[%s] Failed to convert constant to host tensor", constant_node->GetName().c_str());
        converted_outputs.emplace(output_idx);
      }

      src_node_item->to_const_output_id_list.erase(output_idx);
      // Undo the increment above: this dependency no longer counts against the source.
      --ref_counts[src_node_item];
      changed = true;
    }

    if (changed) {
      std::vector<NodePtr> depends_to_keep;
      for (const auto &ref_count_it : ref_counts) {
        if (ref_count_it.second == 0) {
          GELOGD("[%s] no longer depends on [%s] for shape inference",
                 node_item->NodeName().c_str(),
                 ref_count_it.first->NodeName().c_str());
        } else {
          depends_to_keep.emplace_back(ref_count_it.first->node);
        }
      }
      node_item->dependents_for_shape_inference.swap(depends_to_keep);
    }
  }

  return SUCCESS;
}
|
|
|
|
|
// Builds a Tensor from output 0 of `node` and appends it, paired with `output_idx`,
// to hybrid_model_.host_tensors_[node_id].
//
// The tensor description comes from the node's output desc 0. When the size reported by
// TensorUtils::GetTensorSizeInBytes is positive, that many bytes are copied from the
// tensor returned by hybrid_model_.GetTensor(node) with RT_MEMCPY_DEVICE_TO_HOST and
// attached as the tensor data; otherwise the tensor is stored without data.
//
// Returns SUCCESS on completion; the GE_CHECK_* / GE_CHK_* macros handle a missing
// tensor, a missing output desc, a size query failure, a too-small source tensor and
// an rtMemcpy failure.
Status HybridModelBuilder::Convert2HostTensor(const NodePtr &node, int node_id, uint32_t output_idx) {
  auto src_tensor = hybrid_model_.GetTensor(node);
  GE_CHECK_NOTNULL(src_tensor);

  auto output_desc = node->GetOpDesc()->MutableOutputDesc(0);
  GE_CHECK_NOTNULL(output_desc);

  Tensor host_tensor(TensorAdapter::GeTensorDesc2TensorDesc(*output_desc));

  int64_t size_in_bytes = -1;
  GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*output_desc, size_in_bytes),
                          "[%s] Failed to get tensor size", node->GetName().c_str());

  const bool has_payload = size_in_bytes > 0;
  if (has_payload) {
    const size_t byte_count = static_cast<size_t>(size_in_bytes);
    // The source tensor must hold at least as many bytes as are about to be copied.
    GE_CHECK_GE(src_tensor->GetSize(), byte_count);

    std::vector<uint8_t> host_bytes(byte_count);
    GE_CHK_RT_RET(rtMemcpy(host_bytes.data(), byte_count,
                           src_tensor->GetData(), byte_count,
                           RT_MEMCPY_DEVICE_TO_HOST));
    host_tensor.SetData(std::move(host_bytes));
    GELOGD("[%s] Copy constant tensor to host successfully, size = %zu", node->GetName().c_str(), byte_count);
  }

  // Stored even when no bytes were copied.
  auto &tensors_of_node = hybrid_model_.host_tensors_[node_id];
  tensors_of_node.emplace_back(output_idx, std::move(host_tensor));
  return SUCCESS;
}
|
|
|
|
|
} // namespace hybrid
|
|
|
|
|
} // namespace ge
|
|
|
|
|