|
|
@ -150,6 +150,7 @@ class EmbeddingLookup(Cell):
|
|
|
|
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
|
|
|
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
|
|
|
or None. Default: None
|
|
|
|
or None. Default: None
|
|
|
|
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
|
|
|
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
|
|
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
|
|
|
|
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
|
|
|
@ -191,6 +192,12 @@ class EmbeddingLookup(Cell):
|
|
|
|
name='embedding_table')
|
|
|
|
name='embedding_table')
|
|
|
|
parallel_mode = _get_parallel_mode()
|
|
|
|
parallel_mode = _get_parallel_mode()
|
|
|
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
|
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
|
|
|
|
|
|
self.forward_unique = False
|
|
|
|
|
|
|
|
self.gather_revert = P.GatherV2()
|
|
|
|
|
|
|
|
self.unique = P.Unique().shard(((1,),))
|
|
|
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
|
|
|
|
self.shape = P.Shape()
|
|
|
|
|
|
|
|
indices_shape_size = 2
|
|
|
|
if slice_mode == "field_slice" and is_auto_parallel:
|
|
|
|
if slice_mode == "field_slice" and is_auto_parallel:
|
|
|
|
if not manual_shapes:
|
|
|
|
if not manual_shapes:
|
|
|
|
raise ValueError("in slice field mode, the manual_shapes should not be none")
|
|
|
|
raise ValueError("in slice field mode, the manual_shapes should not be none")
|
|
|
@ -203,18 +210,32 @@ class EmbeddingLookup(Cell):
|
|
|
|
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
|
|
|
|
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
|
|
|
|
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
|
|
|
|
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
|
|
|
|
elif slice_mode == "table_row_slice" and is_auto_parallel:
|
|
|
|
elif slice_mode == "table_row_slice" and is_auto_parallel:
|
|
|
|
self.gatherv2.shard(((get_group_size(), 1), (1, 1)))
|
|
|
|
if target == 'DEVICE':
|
|
|
|
self.embeddinglookup.shard(((get_group_size(), 1), (1, 1)))
|
|
|
|
indices_shape_size = 1
|
|
|
|
|
|
|
|
self.gather_revert.shard(((1, 1), (1,)))
|
|
|
|
|
|
|
|
self.forward_unique = True
|
|
|
|
|
|
|
|
indices_strategy = (1,)*indices_shape_size
|
|
|
|
|
|
|
|
self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
|
|
|
|
|
|
|
|
self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
|
|
|
|
elif slice_mode == "table_column_slice" and is_auto_parallel:
|
|
|
|
elif slice_mode == "table_column_slice" and is_auto_parallel:
|
|
|
|
self.gatherv2.shard(((1, get_group_size()), (1, 1)))
|
|
|
|
if target == 'DEVICE':
|
|
|
|
self.embeddinglookup.shard(((1, get_group_size()), (1, 1)))
|
|
|
|
indices_shape_size = 1
|
|
|
|
|
|
|
|
self.gather_revert.shard(((1, get_group_size()), (1,)))
|
|
|
|
|
|
|
|
self.forward_unique = True
|
|
|
|
|
|
|
|
indices_strategy = (1,)*indices_shape_size
|
|
|
|
|
|
|
|
self.gatherv2.shard(((1, get_group_size()), indices_strategy))
|
|
|
|
|
|
|
|
self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
|
|
|
|
elif slice_mode == "batch_slice" and is_auto_parallel:
|
|
|
|
elif slice_mode == "batch_slice" and is_auto_parallel:
|
|
|
|
self.gatherv2.shard(((1, 1), (get_group_size(), 1)))
|
|
|
|
indices_strategy = [get_group_size()]
|
|
|
|
self.embeddinglookup.shard(((1, 1), (get_group_size(), 1)))
|
|
|
|
indices_strategy.extend([1]*(indices_shape_size - 1))
|
|
|
|
|
|
|
|
indices_strategy = tuple(indices_strategy)
|
|
|
|
|
|
|
|
self.gatherv2.shard(((1, 1), indices_strategy))
|
|
|
|
|
|
|
|
self.embeddinglookup.shard(((1, 1), indices_strategy))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if is_auto_parallel:
|
|
|
|
if is_auto_parallel:
|
|
|
|
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
|
|
|
|
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
|
|
|
|
+ str(slice_mode))
|
|
|
|
+ str(slice_mode))
|
|
|
|
|
|
|
|
self.embedding_table.unique = self.forward_unique
|
|
|
|
self.max_norm = max_norm
|
|
|
|
self.max_norm = max_norm
|
|
|
|
if self.max_norm is not None:
|
|
|
|
if self.max_norm is not None:
|
|
|
|
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
|
|
|
|
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
|
|
|
@ -224,7 +245,15 @@ class EmbeddingLookup(Cell):
|
|
|
|
if self.target == "CPU":
|
|
|
|
if self.target == "CPU":
|
|
|
|
out = self.embeddinglookup(self.embedding_table, indices, 0)
|
|
|
|
out = self.embeddinglookup(self.embedding_table, indices, 0)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
out = self.gatherv2(self.embedding_table, indices, 0)
|
|
|
|
if self.forward_unique:
|
|
|
|
|
|
|
|
shp = self.shape(indices) + (self.embedding_size,)
|
|
|
|
|
|
|
|
indices_flatten = self.reshape(indices, (-1,))
|
|
|
|
|
|
|
|
unique_id, unique_idx = self.unique(indices_flatten)
|
|
|
|
|
|
|
|
weight_unique = self.gatherv2(unique_id)
|
|
|
|
|
|
|
|
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
|
|
|
|
|
|
|
|
out = self.reshape(weight_flatten, shp)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
out = self.gatherv2(self.embedding_table, indices, 0)
|
|
|
|
if self.max_norm is not None:
|
|
|
|
if self.max_norm is not None:
|
|
|
|
axis = _make_axis_range(F.rank(indices), F.rank(out))
|
|
|
|
axis = _make_axis_range(F.rank(indices), F.rank(out))
|
|
|
|
clip_by_norm = ClipByNorm(axis)
|
|
|
|
clip_by_norm = ClipByNorm(axis)
|
|
|
|