Add prim name to error message for _grad_ops.py

pull/375/head
fary86 5 years ago
parent c803569648
commit 8bb93411f3

@@ -206,8 +206,8 @@ class Validator:
         def _check_tensor_type(arg):
             arg_key, arg_val = arg
             elem_type = arg_val
+            type_names = []
             if not elem_type in valid_values:
-                type_names = []
                 for t in valid_values:
                     type_names.append(str(t))
                 types_info = '[' + ", ".join(type_names) + ']'
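Side note on the hunk above: the only change is hoisting the "type_names = []" initialization out of the conditional, so the list is bound on every path through the helper. A minimal plain-Python sketch of that shape (hypothetical function name, not MindSpore code):

# Editor's sketch, plain Python with hypothetical names; not part of the commit.
def describe_type_mismatch(elem_type, valid_values):
    type_names = []          # hoisted: bound whether or not the check below fires
    if elem_type not in valid_values:
        for t in valid_values:
            type_names.append(str(t))
        types_info = '[' + ", ".join(type_names) + ']'
        return 'got ' + str(elem_type) + ', expected one of ' + types_info
    return None

print(describe_type_mismatch('float64', ['float16', 'float32']))
# prints: got float64, expected one of [float16, float32]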

@@ -15,7 +15,7 @@
 """utils for operator"""
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
@@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
     return broadcast_shape


-def _get_concat_offset(x_shp, x_type, axis):
+def _get_concat_offset(x_shp, x_type, axis, prim_name):
     """for concat and concatoffset check args and compute offset"""
-    validator.check_type("shape", x_shp, [tuple])
-    validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
-    validator.check_subclass("shape0", x_type[0], mstype.tensor)
-    validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
+    validator.check_value_type("shape", x_shp, [tuple], prim_name)
+    validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
+    validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
+    validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
     rank_base = len(x_shp[0])
-    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
+    validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
     if axis < 0:
         axis = axis + rank_base
     all_shp = x_shp[0][axis]
     offset = [0,]
     for i in range(1, len(x_shp)):
         v = x_shp[i]
-        validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
-        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
+        validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
+        validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
         for j in range(rank_base):
             if j != axis and v[j] != x_shp[0][j]:
-                raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
+                raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
         offset.append(all_shp)
         all_shp += v[axis]
     return offset, all_shp, axis
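The pattern applied throughout this hunk is to thread the primitive name into every check so that a validation failure names the operator (e.g. "For 'Concat' ..."). A rough stand-alone sketch of that pattern, using a hypothetical helper rather than the real mindspore._checkparam.Validator API:

# Editor's sketch of the prim_name-prefixed error pattern; hypothetical helper,
# not the actual Validator implementation.
def check_equal(arg_name, arg_value, ref_name, ref_value, prim_name=None):
    if arg_value != ref_value:
        prefix = "For '{}', ".format(prim_name) if prim_name else ""
        raise ValueError("{}{} ({}) must be equal to {} ({})".format(
            prefix, arg_name, arg_value, ref_name, ref_value))

try:
    check_equal('len of x_shp[1]', 3, 'len of x_shp[0]', 2, prim_name='Concat')
except ValueError as e:
    print(e)   # For 'Concat', len of x_shp[1] (3) must be equal to len of x_shp[0] (2)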

File diff suppressed because it is too large

@@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer):
         axis = self.axis
         x_shp = input_x['shape']
         x_type = input_x['dtype']
-        _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis)
+        _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name)
        self.add_prim_attr('T', x_type[0].element_type())
         self.add_prim_attr('inputNums', len(x_shp))
         ret_shp = x_shp[0].copy()
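With self.name forwarded here, a shape mismatch detected in _get_concat_offset should now identify the primitive in its message. A hedged usage sketch, assuming a MindSpore build from around this commit and that the __infer__ check fires when the op is invoked (exact wording and exception surface may differ):

# Hedged usage sketch; assumes MindSpore is installed and that Concat's shape
# check is triggered on this call.
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

concat = P.Concat(axis=0)
x = Tensor(np.ones((2, 3), np.float32))
y = Tensor(np.ones((2, 4), np.float32))   # non-axis dimension differs from x
try:
    concat((x, y))
except ValueError as e:
    print(e)   # expected to start with "For 'Concat' ..." after this change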
