|
|
|
@ -20,6 +20,7 @@ import copy
|
|
|
|
|
from mindspore.common.api import _wrap_func
|
|
|
|
|
from mindspore.common import Parameter
|
|
|
|
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from .._c_expression import Primitive_, real_run_op, prim_type
|
|
|
|
|
from .._c_expression import signature_rw as sig_rw
|
|
|
|
|
from .._c_expression import signature_kind as sig_kind
|
|
|
|
@ -138,6 +139,8 @@ class Primitive(Primitive_):
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, item):
|
|
|
|
|
if item == 'infer_dynamic_shape':
|
|
|
|
|
return None
|
|
|
|
|
if item in super().get_attr_dict():
|
|
|
|
|
return super().get_attr_dict()[item]
|
|
|
|
|
if item in self.attrs:
|
|
|
|
@ -282,13 +285,49 @@ class PrimitiveWithInfer(Primitive):
|
|
|
|
|
|
|
|
|
|
def __infer__(self, *args):
|
|
|
|
|
"""Infer shape, type, and value at the same time by using dictionary as arguments."""
|
|
|
|
|
is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
|
|
|
|
|
fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
|
|
|
|
|
if is_graph_mode and fn_infer_dynamic_shape is not None:
|
|
|
|
|
out = fn_infer_dynamic_shape(*args)
|
|
|
|
|
tracks = ['dtype', 'value']
|
|
|
|
|
for track in tracks:
|
|
|
|
|
fn = getattr(self, 'infer_' + track)
|
|
|
|
|
# fn may return None
|
|
|
|
|
out[track] = fn(*(x[track] for x in args))
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
tracks = ['dtype', 'shape', 'value']
|
|
|
|
|
out = {}
|
|
|
|
|
for track in tracks:
|
|
|
|
|
fn = getattr(self, 'infer_' + track)
|
|
|
|
|
# fn may return None
|
|
|
|
|
out[track] = fn(*(x[track] for x in args))
|
|
|
|
|
|
|
|
|
|
# in non-graph_mode, it is not necessary to infer min/max shape
|
|
|
|
|
if not is_graph_mode:
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def get_specified_shape(elems, attr):
|
|
|
|
|
has_specified_shape = False
|
|
|
|
|
ret_vals = []
|
|
|
|
|
for elem in elems:
|
|
|
|
|
if attr in elem:
|
|
|
|
|
has_specified_shape = True
|
|
|
|
|
ret_vals.append(elem[attr])
|
|
|
|
|
else:
|
|
|
|
|
ret_vals.append(elem['shape'])
|
|
|
|
|
return has_specified_shape, tuple(ret_vals)
|
|
|
|
|
|
|
|
|
|
has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
|
|
|
|
|
has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
|
|
|
|
|
if not (has_min_shape or has_max_shape):
|
|
|
|
|
return out
|
|
|
|
|
if has_min_shape and has_max_shape:
|
|
|
|
|
fn_infer_shape = getattr(self, 'infer_shape')
|
|
|
|
|
out['min_shape'] = fn_infer_shape(*min_shapes)
|
|
|
|
|
out['max_shape'] = fn_infer_shape(*max_shapes)
|
|
|
|
|
return out
|
|
|
|
|
raise ValueError('Input args has invalid dynamic shape, args info: {args}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prim_attr_register(fn):
|
|
|
|
|