|
|
|
@ -445,14 +445,6 @@ class Reshape(PrimitiveWithInfer):
|
|
|
|
|
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
|
|
|
|
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
|
|
|
|
shape_v = list(shape_v)
|
|
|
|
|
if 'max_shape' in x:
|
|
|
|
|
x_max_shape = x['max_shape']
|
|
|
|
|
else:
|
|
|
|
|
x_max_shape = x['shape']
|
|
|
|
|
if 'min_shape' in x:
|
|
|
|
|
x_min_shape = x['min_shape']
|
|
|
|
|
else:
|
|
|
|
|
x_min_shape = x['shape']
|
|
|
|
|
neg_index = -1
|
|
|
|
|
dim_prod = 1
|
|
|
|
|
for i, shp_i in enumerate(shape_v):
|
|
|
|
@ -464,34 +456,49 @@ class Reshape(PrimitiveWithInfer):
|
|
|
|
|
else:
|
|
|
|
|
dim_prod *= shp_i
|
|
|
|
|
arr_prod = np.prod(x_shp)
|
|
|
|
|
max_arr_prod = np.prod(x_max_shape)
|
|
|
|
|
min_arr_prod = np.prod(x_min_shape)
|
|
|
|
|
if dim_prod <= 0 or arr_prod % dim_prod != 0:
|
|
|
|
|
raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
|
|
|
|
|
f'The product of input_x\'s shape should > 0, '
|
|
|
|
|
f'and can be divided by product of input_shape, '
|
|
|
|
|
f'but product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.')
|
|
|
|
|
max_shape = list(shape_v)
|
|
|
|
|
min_shape = list(shape_v)
|
|
|
|
|
if neg_index != -1:
|
|
|
|
|
shape_v[neg_index] = int(arr_prod / dim_prod)
|
|
|
|
|
max_shape[neg_index] = int(max_arr_prod / dim_prod)
|
|
|
|
|
min_shape[neg_index] = int(min_arr_prod / dim_prod)
|
|
|
|
|
dim_prod *= shape_v[neg_index]
|
|
|
|
|
if dim_prod != arr_prod:
|
|
|
|
|
raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
|
|
|
|
|
f'The product of input_x\'s shape should be equal to product of input_shape, '
|
|
|
|
|
f'but product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.')
|
|
|
|
|
|
|
|
|
|
value = None
|
|
|
|
|
if x['value'] is not None:
|
|
|
|
|
value = Tensor(x['value'].asnumpy().reshape(shape_v))
|
|
|
|
|
|
|
|
|
|
out = {'shape': tuple(shape_v),
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': value,
|
|
|
|
|
'max_shape': tuple(max_shape),
|
|
|
|
|
'min_shape': tuple(min_shape)}
|
|
|
|
|
if arr_prod <= 0:
|
|
|
|
|
if 'max_shape' in x:
|
|
|
|
|
x_max_shape = x['max_shape']
|
|
|
|
|
else:
|
|
|
|
|
x_max_shape = x['shape']
|
|
|
|
|
if 'min_shape' in x:
|
|
|
|
|
x_min_shape = x['min_shape']
|
|
|
|
|
else:
|
|
|
|
|
x_min_shape = x['shape']
|
|
|
|
|
max_arr_prod = np.prod(x_max_shape)
|
|
|
|
|
min_arr_prod = np.prod(x_min_shape)
|
|
|
|
|
max_shape = list(shape_v)
|
|
|
|
|
min_shape = list(shape_v)
|
|
|
|
|
if neg_index != -1:
|
|
|
|
|
max_shape[neg_index] = int(max_arr_prod / dim_prod)
|
|
|
|
|
min_shape[neg_index] = int(min_arr_prod / dim_prod)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f'For dynamic shape, Reshape must have neg index')
|
|
|
|
|
out = {'shape': shape['value'],
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': None,
|
|
|
|
|
'max_shape': tuple(max_shape),
|
|
|
|
|
'min_shape': tuple(min_shape)}
|
|
|
|
|
else:
|
|
|
|
|
if dim_prod <= 0 or arr_prod % dim_prod != 0:
|
|
|
|
|
raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
|
|
|
|
|
f'The product of input_x\'s shape should > 0, '
|
|
|
|
|
f'and can be divided by product of input_shape, but '
|
|
|
|
|
f'product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.')
|
|
|
|
|
if neg_index != -1:
|
|
|
|
|
shape_v[neg_index] = int(arr_prod / dim_prod)
|
|
|
|
|
dim_prod *= shape_v[neg_index]
|
|
|
|
|
if dim_prod != arr_prod:
|
|
|
|
|
raise ValueError(f'For \'{self.name}\' input_x\'s shape is {x_shp}, input_shape\'s value is {shape_v}.'
|
|
|
|
|
f'The product of input_x\'s shape should be equal to product of input_shape, but '
|
|
|
|
|
f'product of input_x\'s shape is {arr_prod}, product of input_shape is {dim_prod}.')
|
|
|
|
|
value = None
|
|
|
|
|
if x['value'] is not None:
|
|
|
|
|
value = Tensor(x['value'].asnumpy().reshape(shape_v))
|
|
|
|
|
|
|
|
|
|
out = {'shape': tuple(shape_v),
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': value}
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -4267,6 +4274,8 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
|
|
|
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
|
|
|
|
params_shp = params['shape']
|
|
|
|
|
if len(params_shp) > 2:
|
|
|
|
|
raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp))
|
|
|
|
|
out_shape = indices['shape'] + params_shp[1:]
|
|
|
|
|
if 'max_shape' in indices:
|
|
|
|
|
out_max_shape = indices['max_shape'] + params_shp[1:]
|
|
|
|
|