|
|
|
@ -15,7 +15,7 @@
|
|
|
|
|
|
|
|
|
|
"""utils for operator"""
|
|
|
|
|
|
|
|
|
|
from ..._checkparam import ParamValidator as validator
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
|
|
|
|
@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
|
|
|
|
return broadcast_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_concat_offset(x_shp, x_type, axis):
|
|
|
|
|
def _get_concat_offset(x_shp, x_type, axis, prim_name):
|
|
|
|
|
"""for concat and concatoffset check args and compute offset"""
|
|
|
|
|
validator.check_type("shape", x_shp, [tuple])
|
|
|
|
|
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
|
|
|
|
|
validator.check_subclass("shape0", x_type[0], mstype.tensor)
|
|
|
|
|
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
|
|
|
|
|
validator.check_value_type("shape", x_shp, [tuple], prim_name)
|
|
|
|
|
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
|
|
|
|
|
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
|
|
|
|
|
validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
|
|
|
|
|
rank_base = len(x_shp[0])
|
|
|
|
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
|
|
|
|
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
|
|
|
|
|
if axis < 0:
|
|
|
|
|
axis = axis + rank_base
|
|
|
|
|
all_shp = x_shp[0][axis]
|
|
|
|
|
offset = [0,]
|
|
|
|
|
for i in range(1, len(x_shp)):
|
|
|
|
|
v = x_shp[i]
|
|
|
|
|
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
|
|
|
|
|
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
|
|
|
|
|
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
|
|
|
|
|
validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
|
|
|
|
|
for j in range(rank_base):
|
|
|
|
|
if j != axis and v[j] != x_shp[0][j]:
|
|
|
|
|
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
|
|
|
|
|
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
|
|
|
|
|
offset.append(all_shp)
|
|
|
|
|
all_shp += v[axis]
|
|
|
|
|
return offset, all_shp, axis
|
|
|
|
|