From 8ac5672abb73b52fbbab09019f4ad4fa7085e677 Mon Sep 17 00:00:00 2001
From: fary86
Date: Mon, 27 Jul 2020 16:48:22 +0800
Subject: [PATCH] Add support for dynamic shape

---
 mindspore/ccsrc/frontend/operator/ops.h            |  2 +
 .../ccsrc/frontend/operator/prim_arrays.cc         | 42 +++++++++++++++
 .../pipeline/jit/static_analysis/prim.cc           | 16 +++++-
 .../ccsrc/pipeline/jit/static_analysis/prim.h      |  4 ++
 mindspore/ccsrc/utils/convert_utils.cc             | 18 +++++--
 mindspore/ccsrc/utils/convert_utils.h              |  4 +-
 mindspore/core/abstract/dshape.cc                  |  3 ++
 mindspore/core/abstract/dshape.h                   | 10 +++-
 mindspore/core/abstract/utils.cc                   | 53 ++++++++++++++++++-
 mindspore/ops/_grad/grad_array_ops.py              | 20 +++++++
 mindspore/ops/_utils/utils.py                      |  5 +-
 mindspore/ops/operations/__init__.py               |  3 +-
 mindspore/ops/operations/_grad_ops.py              | 25 +++++++++
 mindspore/ops/operations/array_ops.py              | 24 ++++++++-
 mindspore/ops/primitive.py                         | 41 +++++++++++++-
 15 files changed, 258 insertions(+), 12 deletions(-)

diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h
index 0404140b73..d21c8540f7 100755
--- a/mindspore/ccsrc/frontend/operator/ops.h
+++ b/mindspore/ccsrc/frontend/operator/ops.h
@@ -113,6 +113,8 @@ inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransDat
 inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
 inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
+inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
+inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
 
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
diff --git a/mindspore/ccsrc/frontend/operator/prim_arrays.cc b/mindspore/ccsrc/frontend/operator/prim_arrays.cc
index 1ed9735307..ea0725ae6e 100644
--- a/mindspore/ccsrc/frontend/operator/prim_arrays.cc
+++ b/mindspore/ccsrc/frontend/operator/prim_arrays.cc
@@ -148,5 +148,47 @@ AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &pri
   ret->set_shape(std::make_shared<Shape>(shape));
   return ret;
 }
+
+AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // inputs: a 1-d Tensor
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+
+  auto shape = input->shape();
+  if (shape->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
+  }
+  std::vector<int> ids_shape = {Shape::SHP_ANY};
+  std::vector<int> min_shape = {1};
+  std::vector<int> max_shape = shape->shape();
+  auto ids =
+    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
+  auto ids_idx = std::make_shared<AbstractTensor>(std::make_shared<Int>(32), shape->shape());
+  // outputs: ids, ids_idx
+  AbstractBasePtrList elements = {ids, ids_idx};
+  return std::make_shared<AbstractTuple>(elements);
+}
+
+AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  // inputs: a 1-d Tensor
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr dout = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  CheckArgsSize(op_name + " dout", dout->elements(), 2);
+  auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0);
+  auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1);
+  if (ids->shape()->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "'s input must be 1.";
+  }
+  if (ids_idx->shape()->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Dims of dout[1] of " << op_name << "'s input must be 1.";
+  }
+
+  // outputs: dx
+  return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index 4f0840b4fc..15c1a36474 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include
 
 #include "frontend/operator/cc_implementations.h"
 #include "frontend/operator/ops.h"
@@ -62,6 +63,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
     {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
     {prim::kPrimPack, {InferImplPack, true}},
+    {prim::kPrimUnique, {InferImplUnique, true}},
+    {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -389,6 +392,14 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
   if (abs_base->isa<AbstractTensor>()) {
     auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
     dic["shape"] = arg_tensor->shape()->shape();
+    if (MsContext::GetInstance()->execution_mode() == kGraphMode) {
+      const auto &min_shape = arg_tensor->shape()->min_shape();
+      const auto &max_shape = arg_tensor->shape()->max_shape();
+      if (!min_shape.empty() && !max_shape.empty()) {
+        dic["min_shape"] = min_shape;
+        dic["max_shape"] = max_shape;
+      }
+    }
     dic["dtype"] = arg_tensor->BuildType();
     dic["value"] = BuildValue(arg_tensor->BuildValue());
   } else if (abs_base->isa()) {
@@ -503,7 +514,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
   if (output["value"].is_none()) {
     auto out_shape = output["shape"];
     auto out_dtype = output["dtype"];
-    return PyListDtype2AbstractTensor(out_shape, out_dtype);
+    py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
+    py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
+
+    return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
index 2f44c173d0..1ed7a2c8fa 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
@@ -244,6 +244,10 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
                                         const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
 
 AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc
index f5a6738c8d..70590da753 100644
--- a/mindspore/ccsrc/utils/convert_utils.cc
+++ b/mindspore/ccsrc/utils/convert_utils.cc
@@ -371,7 +371,8 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
   return ret;
 }
 
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
+AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+                                           const py::object &min_shape, const py::object &max_shape) {
   if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
     auto ret_vec = shape_obj.cast<std::vector<int>>();
     auto ret_dtype = type_obj.cast<TypePtr>();
@@ -382,12 +383,23 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
       return abs_scalar;
     }
     AbstractBasePtr tensor = nullptr;
+    std::vector<int> min_shape_vec;
+    std::vector<int> max_shape_vec;
+    if (!min_shape.is_none()) {
+      min_shape_vec = min_shape.cast<std::vector<int>>();
+    }
+    if (!max_shape.is_none()) {
+      max_shape_vec = max_shape.cast<std::vector<int>>();
+    }
+    auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
     if (ret_dtype->isa<TensorType>()) {
       auto tensor_type = type_obj.cast<std::shared_ptr<TensorType>>();
       MS_EXCEPTION_IF_NULL(tensor_type);
-      tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
+      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
+      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
     } else {
-      tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
+      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
+      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
     }
     return tensor;
   } else if (py::isinstance(shape_obj) && py::isinstance(type_obj)) {
diff --git a/mindspore/ccsrc/utils/convert_utils.h b/mindspore/ccsrc/utils/convert_utils.h
index 3216726214..5597ae4d5e 100644
--- a/mindspore/ccsrc/utils/convert_utils.h
+++ b/mindspore/ccsrc/utils/convert_utils.h
@@ -47,7 +47,9 @@ bool BaseRefToInt(const ValuePtr &v, int *value);
 bool ValueToBool(const ValuePtr &in, bool *out);
 py::object ValuePtrToPyData(const ValuePtr &value);
 
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj);
+AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+                                           const py::object &min_shape = py::none(),
+                                           const py::object &max_shape = py::none());
 
 bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                        const std::shared_ptr<py::object> &ret_val);
diff --git a/mindspore/core/abstract/dshape.cc b/mindspore/core/abstract/dshape.cc
index 74ea1ff7bf..a2cbe0fe62 100644
--- a/mindspore/core/abstract/dshape.cc
+++ b/mindspore/core/abstract/dshape.cc
@@ -67,6 +67,9 @@ std::string Shape::DumpText() const {
   buffer << "[";
   for (size_t i = 0; i < shape_.size(); i++) {
     buffer << (i > 0 ? ", " : "") << shape_[i];
+    if (shape_[i] == SHP_ANY && min_shape_.size() == shape_.size() && max_shape_.size() == shape_.size()) {
+      buffer << "_" << min_shape_[i] << "^" << max_shape_[i];
+    }
   }
   buffer << "]";
   return buffer.str();
diff --git a/mindspore/core/abstract/dshape.h b/mindspore/core/abstract/dshape.h
index 3c0252b00f..4197f73ac0 100644
--- a/mindspore/core/abstract/dshape.h
+++ b/mindspore/core/abstract/dshape.h
@@ -74,16 +74,22 @@ class Shape : public BaseShape {
     (void)std::transform(list.begin(), list.end(), std::back_inserter(shape_),
                          [](const int64_t &value) { return static_cast<int>(value); });
   }
+  Shape(const std::vector<int> &list, const std::vector<int> &min_shape, const std::vector<int> &max_shape)
+      : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {}
   ~Shape() override = default;
   MS_DECLARE_PARENT(Shape, BaseShape)
   std::string ToString() const override;
   std::string DumpText() const override;
   bool operator==(const BaseShape &other) const override;
-  BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_); }
+  BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_, min_shape_, max_shape_); }
   void Broaden() override;
   std::vector<int> &shape() { return shape_; }
+  std::vector<int> &min_shape() { return min_shape_; }
+  std::vector<int> &max_shape() { return max_shape_; }
 
-  std::vector<int> shape_;  // use SHP_ANY to implement the any shape in python
+  std::vector<int> shape_;      // use SHP_ANY to implement the any shape in python
+  std::vector<int> min_shape_;  // record minimum length for each dynamic dimension
+  std::vector<int> max_shape_;  // record maximum length for each dynamic dimension
 };
 using ShapePtr = std::shared_ptr<Shape>;
 using ShapePtrList = std::vector<ShapePtr>;
diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc
index 16497c74a9..20eeab0de5 100644
--- a/mindspore/core/abstract/utils.cc
+++ b/mindspore/core/abstract/utils.cc
@@ -55,15 +55,66 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
     return shape1;
   }
   std::vector<int> dims;
+  bool has_dynamic_shape = false;
   dims.resize(shape1->shape().size());
   for (std::size_t i = 0; i < shape1->shape().size(); i++) {
     if (shape1->shape()[i] == shape2->shape()[i]) {
       dims[i] = shape1->shape()[i];
+      if (shape1->shape()[i] == Shape::SHP_ANY) {
+        has_dynamic_shape = true;
+      }
     } else {
       dims[i] = Shape::SHP_ANY;
+      has_dynamic_shape = true;
     }
   }
-  return std::make_shared<Shape>(dims);
+  if (!has_dynamic_shape) {
+    return std::make_shared<Shape>(dims);
+  }
+  // calculate dynamic shape
+  std::vector<int> min_dims(dims.size());
+  std::vector<int> max_dims(dims.size());
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] != Shape::SHP_ANY) {
+      min_dims[i] = max_dims[i] = dims[i];
+      continue;
+    }
+    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
+      min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
+      max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
+      continue;
+    }
+    if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
+      if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
+        MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
+                                 << " has dynamic shape, but does not have min/max shape info.";
+      }
+      min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
+      max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
+      continue;
+    }
+    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
+      if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
+        MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
+                                 << " has dynamic shape, but does not have min/max shape info.";
+      }
+      min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
+      max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
+      continue;
+    }
+    // both shapes contain dynamic dimensions
+    if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
+      MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
+                               << " has dynamic shape, but does not have min/max shape info.";
+    }
+    if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
+      MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
+                               << " has dynamic shape, but does not have min/max shape info.";
+    }
+    min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
+    max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
+  }
+  return std::make_shared<Shape>(dims, min_dims, max_dims);
 }
 
 AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index d8d3328b94..ca22806cd9 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -807,3 +807,23 @@ def get_bprop_trans_shape(self):
         dx = op(dout, shape_op(x))
         return (dx, zeros_like(shape))
     return bprop
+
+
+@bprop_getters.register(P.Unique)
+def get_bprop_unique(self):
+    """Generate bprop for Unique"""
+    op = G.UniqueGrad()
+    def bprop(x, out, dout):
+        dx = op(dout, out)
+        return (dx,)
+    return bprop
+
+
+@bprop_getters.register(P.UnsortedSegmentSum)
+def get_bprop_unsorted_segment_sum(self):
+    """Generate bprop for UnsortedSegmentSum"""
+    op = G.UnsortedSegmentSumGrad()
+    def bprop(x, segment_ids, num_segments, out, dout):
+        dx = op(dout, segment_ids)
+        return (dx, zeros_like(segment_ids), zeros_like(num_segments))
+    return bprop
diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py
index 0e6850dcb1..9ee599e6ef 100644
--- a/mindspore/ops/_utils/utils.py
+++ b/mindspore/ops/_utils/utils.py
@@ -82,5 +82,8 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
         if j != axis and v[j] != x_shp[0][j]:
             raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
         offset.append(all_shp)
-        all_shp += v[axis]
+        if all_shp == -1 or v[axis] == -1:
+            all_shp = -1
+        else:
+            all_shp += v[axis]
     return offset, all_shp, axis
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 1c76b737c7..66a0c65089 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -32,7 +32,8 @@
 from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
-                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup)
+                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
+                        Unique)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice,
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 5663dadc76..1a10590a70 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -491,6 +491,31 @@ class FusedBatchNormGrad(Primitive):
         raise NotImplementedError
 
 
+class UniqueGrad(Primitive):
+    """Gradients of Unique operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
+
+    def __call__(self, dy, y):
+        raise NotImplementedError
+
+
+class UnsortedSegmentSumGrad(PrimitiveWithInfer):
+    """Gradients of UnsortedSegmentSum operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['grads', 'ids'], outputs=['y'])
+
+    def infer_shape(self, grads, ids):
+        return ids + grads[len(ids):]
+
+    def infer_dtype(self, grads, ids):
+        return grads
+
+
 class BNTrainingReduceGrad(PrimitiveWithInfer):
     """Gradients of FusedBatchNorm operation."""
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 2cc26d61eb..734d9e3fc6 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -27,7 +27,7 @@ import numpy as np
 
 from .._utils import get_concat_offset
 from ..operations.math_ops import _infer_shape_reduce
-from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
+from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
 from ..._c_expression import signature_dtype as sig_dtype
 from ..._c_expression import signature_kind as sig_kind
 from ..._c_expression import signature_rw as sig_rw
@@ -556,6 +556,28 @@ class Transpose(PrimitiveWithInfer):
         return out
 
 
+class Unique(Primitive):
+    """
+    Returns the unique elements of the input tensor and also returns a tensor containing the index of each value
+    of the input tensor corresponding to the output unique tensor.
+
+    Inputs:
+        - **x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tuple, containing tensor objects `(y, idx)`. `y` is a tensor with the same type as `x`, and `idx` is a
+        tensor containing indices of the elements in the input corresponding to the output tensor.
+
+    Examples:
+        >>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.float32)
+        >>> out = P.Unique()(x)
+        (Tensor([1, 2, 5], mindspore.float32), Tensor([0, 1, 2, 1], mindspore.int32))
+    """
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+
 class GatherV2(PrimitiveWithInfer):
     """
     Returns a slice of input tensor based on the specified indices and axis.
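
Note (illustration only, not part of the patch): the bprop registered above computes dx = G.UniqueGrad()(dout, out).
Conceptually it relies on the usual unique-with-index relationship x[i] == y[idx[i]], so the incoming gradient of y
is gathered back to x through idx. A minimal NumPy sketch of that relationship follows; the helper names are
hypothetical and this is not the MindSpore kernel.

    import numpy as np

    def unique_forward(x):
        # mirrors y, idx = Unique()(x); np.unique sorts, which coincides with this example
        y, idx = np.unique(x, return_inverse=True)
        return y, idx

    def unique_backward(dy, idx):
        # each x[i] contributed only to y[idx[i]], so the gradient w.r.t. x gathers dy by idx
        return dy[idx]

    x = np.array([1., 2., 5., 2.], dtype=np.float32)
    y, idx = unique_forward(x)      # y = [1., 2., 5.], idx = [0, 1, 2, 1]
    dy = np.ones_like(y)            # stand-in for the upstream gradient of y
    dx = unique_backward(dy, idx)   # dx = [1., 1., 1., 1.]
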
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 1114a6bf57..d0615ee4c8 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -20,6 +20,7 @@ import copy
 from mindspore.common.api import _wrap_func
 from mindspore.common import Parameter
 from mindspore.common._register_for_tensor import tensor_operator_registry
+from mindspore import context
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._c_expression import signature_rw as sig_rw
 from .._c_expression import signature_kind as sig_kind
@@ -138,6 +139,8 @@ class Primitive(Primitive_):
         return self
 
     def __getattr__(self, item):
+        if item == 'infer_dynamic_shape':
+            return None
         if item in super().get_attr_dict():
             return super().get_attr_dict()[item]
         if item in self.attrs:
@@ -282,13 +285,49 @@ class PrimitiveWithInfer(Primitive):
 
     def __infer__(self, *args):
         """Infer shape, type, and value at the same time by using dictionary as arguments."""
+        is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
+        fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
+        if is_graph_mode and fn_infer_dynamic_shape is not None:
+            out = fn_infer_dynamic_shape(*args)
+            tracks = ['dtype', 'value']
+            for track in tracks:
+                fn = getattr(self, 'infer_' + track)
+                # fn may return None
+                out[track] = fn(*(x[track] for x in args))
+            return out
+
         tracks = ['dtype', 'shape', 'value']
         out = {}
         for track in tracks:
             fn = getattr(self, 'infer_' + track)
             # fn may return None
             out[track] = fn(*(x[track] for x in args))
-        return out
+
+        # in non-graph mode, it is not necessary to infer min/max shape
+        if not is_graph_mode:
+            return out
+
+        def get_specified_shape(elems, attr):
+            has_specified_shape = False
+            ret_vals = []
+            for elem in elems:
+                if attr in elem:
+                    has_specified_shape = True
+                    ret_vals.append(elem[attr])
+                else:
+                    ret_vals.append(elem['shape'])
+            return has_specified_shape, tuple(ret_vals)
+
+        has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
+        has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
+        if not (has_min_shape or has_max_shape):
+            return out
+        if has_min_shape and has_max_shape:
+            fn_infer_shape = getattr(self, 'infer_shape')
+            out['min_shape'] = fn_infer_shape(*min_shapes)
+            out['max_shape'] = fn_infer_shape(*max_shapes)
+            return out
+        raise ValueError(f'Input args have invalid dynamic shape, args info: {args}')
 
 
 def prim_attr_register(fn):
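
Note (illustration only, not part of the patch): with the __infer__ rework above, a Python operator that defines only
the usual infer_shape/infer_dtype gets min_shape/max_shape propagation for free in graph mode, because infer_shape
is re-run on the per-input min and max tuples whenever every input provides them. The sketch below uses a
hypothetical pass-through operator to show that flow; it relies only on APIs that appear in this patch
(PrimitiveWithInfer, prim_attr_register, init_prim_io_names).

    from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register

    class PassThroughDemo(PrimitiveWithInfer):
        """Hypothetical pass-through operator used only to illustrate dynamic-shape inference."""

        @prim_attr_register
        def __init__(self):
            self.init_prim_io_names(inputs=['x'], outputs=['y'])

        def infer_shape(self, x_shape):
            # In graph mode, if every input dict carries 'min_shape' and 'max_shape',
            # __infer__ calls this method again with those tuples, so the output dict
            # gains 'min_shape' and 'max_shape' entries without extra operator code.
            return x_shape

        def infer_dtype(self, x_dtype):
            return x_dtype

    # An operator may instead define infer_dynamic_shape(*args); in graph mode
    # __infer__ then delegates the shape handling to that method entirely.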