From 34d6d17ccfbff84b16655ea8e57a17ab8429a180 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Wed, 2 Dec 2020 16:35:24 +0800 Subject: [PATCH] add special size for optune --- ge/graph/build/graph_builder.cc | 26 ++++++++++++------- .../load/new_model_manager/davinci_model.cc | 7 ++++- metadef | 2 +- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/ge/graph/build/graph_builder.cc b/ge/graph/build/graph_builder.cc index 79e46f50..e434709a 100644 --- a/ge/graph/build/graph_builder.cc +++ b/ge/graph/build/graph_builder.cc @@ -448,24 +448,32 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { auto node_op_desc = node_ptr->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); // set dst_node.input_desc = src_node.output_desc - ge::GeTensorDesc desc_temp(src_op->GetOutputDesc(peer_out_anchor->GetIdx())); - + auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx()); int64_t size = 0; - GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); + GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!")); GELOGD("src node %s output desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", src_node->GetName().c_str(), - desc_temp.GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(desc_temp.GetFormat()).c_str(), - TypeUtils::DataTypeToSerialString(desc_temp.GetDataType()).c_str()); - for (size_t i = 0; i < desc_temp.GetShape().GetDimNum(); ++i) { - GELOGD("dims[%zu]: %ld", i, desc_temp.GetShape().GetDim(i)); + output_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(output_desc->GetFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()).c_str()); + for (size_t i = 0; i < output_desc->GetShape().GetDimNum(); ++i) { + GELOGD("dims[%zu]: %ld", i, output_desc->GetShape().GetDim(i)); } - auto input_desc = node_op_desc->GetInputDescPtr(in_data_anchor->GetIdx()); + auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx()); GE_CHECK_NOTNULL(input_desc); - ge::TensorUtils::SetSize(const_cast(*input_desc), size); + (void) ge::TensorUtils::SetSize(*input_desc, size); GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc)); GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); + // inherit some attr + int64_t tensor_size_attr; + if (AttrUtils::GetInt(output_desc, ATTR_NAME_SPECIAL_OUTPUT_SIZE, tensor_size_attr) && (tensor_size_attr > 0)) { + GE_IF_BOOL_EXEC(!AttrUtils::SetInt(*input_desc, ATTR_NAME_SPECIAL_OUTPUT_SIZE, tensor_size_attr), + GELOGW("Set size attr failed!"); continue); + GELOGD("node[%s] [%d]th output has sepcial size[%ld], and update to node[%s] [%d]th input", + src_op->GetName().c_str(), peer_out_anchor->GetIdx(), tensor_size_attr, + node_op_desc->GetName().c_str(), in_data_anchor->GetIdx()); + } } return SUCCESS; diff --git a/ge/graph/load/new_model_manager/davinci_model.cc b/ge/graph/load/new_model_manager/davinci_model.cc index 93cb8d89..425ce199 100755 --- a/ge/graph/load/new_model_manager/davinci_model.cc +++ b/ge/graph/load/new_model_manager/davinci_model.cc @@ -2109,7 +2109,12 @@ void DavinciModel::CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputD } int64_t tensor_size = 0; - (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); // no need to check value + if (AttrUtils::GetInt(op_desc->GetInputDescPtr(index), ATTR_NAME_SPECIAL_OUTPUT_SIZE, tensor_size) + && (tensor_size > 0)) { + GELOGI("netoutput[%s] [%d]th input has special size [%ld]", op_desc->GetName().c_str(), index, tensor_size); + } else { + (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); // no need to check value + } output.size = static_cast(tensor_size); output.data_type = op_desc->GetInputDescPtr(index)->GetDataType(); } diff --git a/metadef b/metadef index 6995fa36..cb50fa2c 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 6995fa3682b9e1147c5173e56192126d2f91a2b7 +Subproject commit cb50fa2c2141bc5bc679bc47949ed8247850406e