|
|
|
@ -448,24 +448,32 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) {
|
|
|
|
|
auto node_op_desc = node_ptr->GetOpDesc();
|
|
|
|
|
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
|
|
|
|
|
// set dst_node.input_desc = src_node.output_desc
|
|
|
|
|
ge::GeTensorDesc desc_temp(src_op->GetOutputDesc(peer_out_anchor->GetIdx()));
|
|
|
|
|
|
|
|
|
|
auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx());
|
|
|
|
|
int64_t size = 0;
|
|
|
|
|
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!"));
|
|
|
|
|
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!"));
|
|
|
|
|
GELOGD("src node %s output desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", src_node->GetName().c_str(),
|
|
|
|
|
desc_temp.GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(desc_temp.GetFormat()).c_str(),
|
|
|
|
|
TypeUtils::DataTypeToSerialString(desc_temp.GetDataType()).c_str());
|
|
|
|
|
for (size_t i = 0; i < desc_temp.GetShape().GetDimNum(); ++i) {
|
|
|
|
|
GELOGD("dims[%zu]: %ld", i, desc_temp.GetShape().GetDim(i));
|
|
|
|
|
output_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(output_desc->GetFormat()).c_str(),
|
|
|
|
|
TypeUtils::DataTypeToSerialString(output_desc->GetDataType()).c_str());
|
|
|
|
|
for (size_t i = 0; i < output_desc->GetShape().GetDimNum(); ++i) {
|
|
|
|
|
GELOGD("dims[%zu]: %ld", i, output_desc->GetShape().GetDim(i));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_desc = node_op_desc->GetInputDescPtr(in_data_anchor->GetIdx());
|
|
|
|
|
auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx());
|
|
|
|
|
GE_CHECK_NOTNULL(input_desc);
|
|
|
|
|
ge::TensorUtils::SetSize(const_cast<GeTensorDesc &>(*input_desc), size);
|
|
|
|
|
(void) ge::TensorUtils::SetSize(*input_desc, size);
|
|
|
|
|
GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc));
|
|
|
|
|
GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(),
|
|
|
|
|
input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(),
|
|
|
|
|
TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str());
|
|
|
|
|
// inherit some attr
|
|
|
|
|
int64_t tensor_size_attr;
|
|
|
|
|
if (AttrUtils::GetInt(output_desc, ATTR_NAME_SPECIAL_OUTPUT_SIZE, tensor_size_attr) && (tensor_size_attr > 0)) {
|
|
|
|
|
GE_IF_BOOL_EXEC(!AttrUtils::SetInt(*input_desc, ATTR_NAME_SPECIAL_OUTPUT_SIZE, tensor_size_attr),
|
|
|
|
|
GELOGW("Set size attr failed!"); continue);
|
|
|
|
|
GELOGD("node[%s] [%d]th output has sepcial size[%ld], and update to node[%s] [%d]th input",
|
|
|
|
|
src_op->GetName().c_str(), peer_out_anchor->GetIdx(), tensor_size_attr,
|
|
|
|
|
node_op_desc->GetName().c_str(), in_data_anchor->GetIdx());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|