From 17428ef7a8e84ebb1d66cb43c516667d1a8ad915 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 26 Dec 2020 14:09:45 +0800 Subject: [PATCH] Fix storage bug. --- ge/generator/ge_generator.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 01f02811..38d422eb 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -265,7 +265,7 @@ static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag) return SUCCESS; } -static void ResetTensorVecShape(const vector &inputs, vector &inputs_dynamic) { +static Status ResetTensorVecShape(const vector &inputs, vector &inputs_dynamic) { for (auto input : inputs) { auto input_desc = input.GetTensorDesc(); GeShape shape_ori = input_desc.GetShape(); @@ -280,6 +280,12 @@ static void ResetTensorVecShape(const vector &inputs, vector bool is_const = false; (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); if (!is_const && shape_ori.GetDims().size() > 0) { + int64_t storage_format = FORMAT_NCHW; + if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) && + !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) { + GELOGE(FAILED, "Set attr ATTR_NAME_STORAGE_SHAPE fail."); + return FAILED; + } desc.SetShape(dynamic_shape); desc.SetShapeRange(dynamic_shape_range); } @@ -287,6 +293,7 @@ static void ResetTensorVecShape(const vector &inputs, vector inputTensor.SetTensorDesc(desc); inputs_dynamic.push_back(inputTensor); } + return SUCCESS; } class GeGenerator::Impl { @@ -688,8 +695,8 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in if (CheckShapeReset(op_desc, dynamic_flag) == SUCCESS && dynamic_flag) { vector inputs_dynamic; vector outputs_dynamic; - ResetTensorVecShape(inputs, inputs_dynamic); - ResetTensorVecShape(outputs, outputs_dynamic); + GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic)); + GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic)); GE_CHK_STATUS_RET_NOLOG( impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic)); } else {