From 23476a783e6e4d7cd1d19ffcc793d3f4983c2f7e Mon Sep 17 00:00:00 2001
From: caifubi
Date: Fri, 2 Apr 2021 17:25:57 +0800
Subject: [PATCH] Add IsDynamic to BaseShape and optimize PyNative operf

---
 .../ccsrc/pipeline/pynative/pynative_execute.cc |  4 ++--
 mindspore/core/abstract/dshape.h                | 12 ++++++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index 17947d256f..91e845da43 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -850,8 +850,8 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
   MS_EXCEPTION_IF_NULL(abstract);
   auto shape = abstract->BuildShape();
   MS_EXCEPTION_IF_NULL(shape);
-  auto shape_info = shape->ToString();
-  if (shape_info.find("-1") != string::npos) {
+
+  if (shape->IsDynamic()) {
     op_exec_info->is_dynamic_shape = true;
   }
 }
diff --git a/mindspore/core/abstract/dshape.h b/mindspore/core/abstract/dshape.h
index 1f2f7836ac..04424937bd 100644
--- a/mindspore/core/abstract/dshape.h
+++ b/mindspore/core/abstract/dshape.h
@@ -46,6 +46,7 @@ class BaseShape : public Base {
   virtual bool operator==(const BaseShape &other) const;
   bool operator!=(const BaseShape &other) const;
   std::size_t hash() const override { return tid(); }
+  virtual bool IsDynamic() const = 0;
 
   // return a deep copy
   virtual BaseShapePtr Clone() const = 0;
@@ -57,6 +58,7 @@ class NoShape : public BaseShape {
   MS_DECLARE_PARENT(NoShape, BaseShape)
   BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); }
   std::string ToString() const override { return type_name(); }
+  bool IsDynamic() const override { return false; }
 };
 
 extern const std::shared_ptr<NoShape> kNoShape;
@@ -78,10 +80,13 @@ class Shape : public BaseShape {
   ShapeVector &shape() { return shape_; }
   ShapeVector &min_shape() { return min_shape_; }
   ShapeVector &max_shape() { return max_shape_; }
+  bool IsDynamic() const override {
+    return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
+  }
 
   ShapeVector shape_;      // use SHP_ANY to implement the any shape in python
-  ShapeVector min_shape_;  // record mininum length for each dynamic dimention
-  ShapeVector max_shape_;  // record maximum length for each dynamic dimention
+  ShapeVector min_shape_;  // record minimum length for each dynamic dimension
+  ShapeVector max_shape_;  // record maximum length for each dynamic dimension
 };
 using ShapePtr = std::shared_ptr<Shape>;
 using ShapePtrList = std::vector<ShapePtr>;
@@ -102,6 +107,9 @@ class SequeueShape : public BaseShape {
   const BaseShapePtrList &shape() const { return p_shapes_; }
   size_t size() const { return p_shapes_.size(); }
   const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; }
+  bool IsDynamic() const override {
+    return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
+  }
 
  protected:
   BaseShapePtrList p_shapes_;  // shape list of each elements
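
Note (not part of the patch above): below is a minimal standalone C++ sketch of the idea behind this change. The simplified BaseShape/Shape classes and the ShapeVector alias are assumptions made for illustration only, not the MindSpore declarations; the sketch just contrasts the old "scan ToString() for -1" check used in ForwardExecutor::GetOpOutputAbstract with the new virtual IsDynamic() call, which avoids building and searching a string on every PyNative op launch.

// Standalone sketch, simplified from the patched classes; not MindSpore source.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

using ShapeVector = std::vector<int64_t>;  // stand-in for MindSpore's ShapeVector

class BaseShape {
 public:
  virtual ~BaseShape() = default;
  virtual std::string ToString() const = 0;
  virtual bool IsDynamic() const = 0;  // interface introduced by the patch
};

class Shape : public BaseShape {
 public:
  explicit Shape(ShapeVector shape) : shape_(std::move(shape)) {}

  std::string ToString() const override {
    std::ostringstream oss;
    oss << '(';
    for (size_t i = 0; i < shape_.size(); ++i) {
      if (i > 0) oss << ", ";
      oss << shape_[i];
    }
    oss << ')';
    return oss.str();
  }

  // A negative dimension (-1 / SHP_ANY in MindSpore) marks a dynamic axis.
  bool IsDynamic() const override {
    return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
  }

 private:
  ShapeVector shape_;
};

int main() {
  const std::shared_ptr<BaseShape> shape = std::make_shared<Shape>(ShapeVector{2, -1, 4});

  // Old check: build a string and scan it for "-1" on every op launch.
  const bool dynamic_via_string = shape->ToString().find("-1") != std::string::npos;

  // New check: one virtual call over the stored dimensions, no string allocation.
  const bool dynamic_via_flag = shape->IsDynamic();

  std::cout << dynamic_via_string << " " << dynamic_via_flag << std::endl;  // prints "1 1"
  return 0;
}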