unsortsegment_sum infershape in dynamicshape fi

pull/9442/head
yao_yf 4 years ago
parent 747bc87ab3
commit 519f415d6b

@ -224,7 +224,7 @@ class EmbeddingLookup(Cell):
elif slice_mode == "table_row_slice" and is_auto_parallel:
if target == 'DEVICE':
indices_shape_size = 1
self.gather_revert.shard(((1, 1), (1,)))
self.gather_revert.shard(((1, 1), (get_group_size(),)))
self.forward_unique = True
indices_strategy = (1,)*indices_shape_size
self.gatherv2.shard(((get_group_size(), 1), indices_strategy))

@ -1883,18 +1883,27 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
shp = [num_segments_v]
shp += x_shp[segment_ids_shp_len:]
if 'max_shape' in x:
output_incoming = x['max_shape']
if "max_value" in num_segments and "min_value" in num_segments:
output_max_shape = list(num_segments['max_value'])
output_min_shape = list(num_segments['min_value'])
else:
if isinstance(num_segments_type, type(mstype.tensor)):
raise ValueError("In dynamic shape scene, the num_segments should contains max_value and min_value")
output_max_shape = [num_segments_v]
output_max_shape += output_incoming[segment_ids_shp_len:]
output_min_shape = [num_segments_v]
if 'max_shape' in x and 'min_shape' in x:
max_output_incoming = x['max_shape']
min_output_incoming = x['min_shape']
else:
output_max_shape = x_shp
out = {'shape': shp,
'max_shape': output_max_shape,
'min_shape': [1] * segment_ids_shp_len + x_shp[segment_ids_shp_len:],
'dtype': mstype.tensor_type(x_type.element_type()),
'value': None}
return out
max_output_incoming = x_shp
min_output_incoming = x_shp
output_max_shape += max_output_incoming[segment_ids_shp_len:]
output_min_shape += min_output_incoming[segment_ids_shp_len:]
return {'shape': shp,
'max_shape': output_max_shape,
'min_shape': output_min_shape,
'dtype': mstype.tensor_type(x_type.element_type()),
'value': None}
class UnsortedSegmentMin(PrimitiveWithCheck):
@ -2688,6 +2697,26 @@ class StridedSlice(PrimitiveWithInfer):
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type())
if "max_value" in x and "min_value" in x:
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name)
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name)
max_value_np = np.array(x["max_value"])
min_value_np = np.array(x["min_value"])
slice_index = []
for begin_i, end_i, strides_i in zip(begin_v, end_v, strides_v):
s = slice(begin_i, end_i, strides_i)
slice_index.append(s)
slice_index = tuple(slice_index)
max_value_slice = max_value_np[slice_index]
min_value_slice = min_value_np[slice_index]
max_value_slice = tuple(max_value_slice.tolist())
min_value_slice = tuple(min_value_slice.tolist())
return {'shape': ret_shape,
'dtype': x['dtype'],
'value': value,
'max_value': max_value_slice,
'min_value': min_value_slice}
return {'shape': ret_shape,
'dtype': x['dtype'],
'value': value}

@ -207,8 +207,6 @@ class WideDeepModel(nn.Cell):
target = 'CPU'
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
if target == 'DEVICE':
self.wide_mul.shard(((1, 1, 1), (1, 1, 1)))
if config.deep_table_slice_mode == "column_slice":
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)

Loading…
Cancel
Save