|
|
|
@ -1883,18 +1883,27 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|
|
|
|
shp = [num_segments_v]
|
|
|
|
|
|
|
|
|
|
shp += x_shp[segment_ids_shp_len:]
|
|
|
|
|
if 'max_shape' in x:
|
|
|
|
|
output_incoming = x['max_shape']
|
|
|
|
|
if "max_value" in num_segments and "min_value" in num_segments:
|
|
|
|
|
output_max_shape = list(num_segments['max_value'])
|
|
|
|
|
output_min_shape = list(num_segments['min_value'])
|
|
|
|
|
else:
|
|
|
|
|
if isinstance(num_segments_type, type(mstype.tensor)):
|
|
|
|
|
raise ValueError("In dynamic shape scene, the num_segments should contains max_value and min_value")
|
|
|
|
|
output_max_shape = [num_segments_v]
|
|
|
|
|
output_max_shape += output_incoming[segment_ids_shp_len:]
|
|
|
|
|
output_min_shape = [num_segments_v]
|
|
|
|
|
if 'max_shape' in x and 'min_shape' in x:
|
|
|
|
|
max_output_incoming = x['max_shape']
|
|
|
|
|
min_output_incoming = x['min_shape']
|
|
|
|
|
else:
|
|
|
|
|
output_max_shape = x_shp
|
|
|
|
|
out = {'shape': shp,
|
|
|
|
|
'max_shape': output_max_shape,
|
|
|
|
|
'min_shape': [1] * segment_ids_shp_len + x_shp[segment_ids_shp_len:],
|
|
|
|
|
'dtype': mstype.tensor_type(x_type.element_type()),
|
|
|
|
|
'value': None}
|
|
|
|
|
return out
|
|
|
|
|
max_output_incoming = x_shp
|
|
|
|
|
min_output_incoming = x_shp
|
|
|
|
|
output_max_shape += max_output_incoming[segment_ids_shp_len:]
|
|
|
|
|
output_min_shape += min_output_incoming[segment_ids_shp_len:]
|
|
|
|
|
return {'shape': shp,
|
|
|
|
|
'max_shape': output_max_shape,
|
|
|
|
|
'min_shape': output_min_shape,
|
|
|
|
|
'dtype': mstype.tensor_type(x_type.element_type()),
|
|
|
|
|
'value': None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnsortedSegmentMin(PrimitiveWithCheck):
|
|
|
|
@ -2688,6 +2697,26 @@ class StridedSlice(PrimitiveWithInfer):
|
|
|
|
|
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
|
|
|
|
|
|
|
|
|
|
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type())
|
|
|
|
|
if "max_value" in x and "min_value" in x:
|
|
|
|
|
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name)
|
|
|
|
|
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name)
|
|
|
|
|
max_value_np = np.array(x["max_value"])
|
|
|
|
|
min_value_np = np.array(x["min_value"])
|
|
|
|
|
slice_index = []
|
|
|
|
|
for begin_i, end_i, strides_i in zip(begin_v, end_v, strides_v):
|
|
|
|
|
s = slice(begin_i, end_i, strides_i)
|
|
|
|
|
slice_index.append(s)
|
|
|
|
|
slice_index = tuple(slice_index)
|
|
|
|
|
max_value_slice = max_value_np[slice_index]
|
|
|
|
|
min_value_slice = min_value_np[slice_index]
|
|
|
|
|
max_value_slice = tuple(max_value_slice.tolist())
|
|
|
|
|
min_value_slice = tuple(min_value_slice.tolist())
|
|
|
|
|
return {'shape': ret_shape,
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': value,
|
|
|
|
|
'max_value': max_value_slice,
|
|
|
|
|
'min_value': min_value_slice}
|
|
|
|
|
|
|
|
|
|
return {'shape': ret_shape,
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': value}
|
|
|
|
|