!751 Fix bug of modify output shape to -2.

From: @zhao_zhixuan
Reviewed-by: @ji_chen,@xchu42
Signed-off-by: @xchu42
pull/751/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 51314c970b

@ -262,6 +262,15 @@ static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag)
change_shape_flag = true; change_shape_flag = true;
} }
} }
for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) {
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i));
GE_CHECK_NOTNULL(output_desc);
// pass scalar output desc
auto dims = output_desc->GetShape().GetDims();
if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) {
change_shape_flag = true;
}
}
return SUCCESS; return SUCCESS;
} }

@ -113,16 +113,13 @@ Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim}; std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim};
GeShape dynamic_shape(dynamic_shape_dims); GeShape dynamic_shape(dynamic_shape_dims);
bool reset_shape_flag = false; (void)ResetInputTensorShape(op_desc, dynamic_shape);
if (ResetInputTensorShape(op_desc, dynamic_shape, reset_shape_flag) == SUCCESS && reset_shape_flag) { (void)ResetOutputTensorShape(op_desc, dynamic_shape);
(void)ResetOutputTensorShape(op_desc, dynamic_shape);
}
return SUCCESS; return SUCCESS;
} }
Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc,
bool &reset_shape_flag) { const GeShape &dynamic_shape) {
reset_shape_flag = false;
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) {
auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i)); auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
@ -136,7 +133,6 @@ Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc,
if (CheckIfConstInput(input_desc)) { if (CheckIfConstInput(input_desc)) {
continue; continue;
} }
reset_shape_flag = true;
input_desc->SetShape(dynamic_shape); input_desc->SetShape(dynamic_shape);
} }
return SUCCESS; return SUCCESS;

@ -27,7 +27,7 @@ class DynamicSingleOpResetShapePass : public GraphPass {
private: private:
Status ResetOpShape(OpDescPtr &op_desc); Status ResetOpShape(OpDescPtr &op_desc);
Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, bool &reset_shape_flag); Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape);
Status ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape); Status ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape);
Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu); Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu);
bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc); bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc);

Loading…
Cancel
Save