@@ -18,10 +18,14 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore._checkparam import Validator
+from mindspore.communication.management import get_group_size
+from mindspore.train.parallel_utils import ParallelMode
+from mindspore.parallel._utils import _get_parallel_mode
 from ..cell import Cell
-from ..._checkparam import Validator as validator
+from ..._checkparam import Validator as validator, Rel

-__all__ = ['Embedding', 'EmbeddingLookup']
+__all__ = ['Embedding', 'EmbeddingLookup', 'EmbeddingLookUpSplitMode']

 class Embedding(Cell):
     r"""
@@ -114,29 +118,36 @@ class EmbeddingLookup(Cell):
     When 'target' is set to 'CPU', this module will use
     P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
     specified 'offset = 0' to lookup table.
-    when 'target' is set to 'DEVICE', this module will use P.GatherV2() which
+    When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
     specified 'axis = 0' to lookup table.
+    In field slice mode, manual_shapes must be given. It is a tuple, where each
+    element is (vocab[i], offset[i]): vocab[i] is the number of rows of the i-th
+    part and offset[i] is the feature id offset of the i-th part. The feature ids
+    in the i-th part will be subtracted by offset[i] so that the ids start from 0.

     Args:
+        vocab_size (int): Size of the dictionary of embeddings.
+        embedding_size (int): The size of each embedding vector.
+        param_init (str): The initialization method of the embedding table. Default: 'normal'.
         target (str): Specify the target where the op is executed. Default: 'CPU'.
+        slice_mode (str): The slicing way in semi auto parallel/auto parallel. Default: 'batch_slice'.
+        manual_shapes (tuple): The accompaniment array in field slice mode.

     Inputs:
-        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
-          The Tensor slice, instead of the entire Tensor.
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
-          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
-          and the exceeding part will be filled with 0 in the output.
+          Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
+          and the exceeding part will be filled with 0 in the output. input_indices must be a 2d tensor in
+          this interface.

     Outputs:
         Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

     Examples:
-        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
         >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
-        >>> out = nn.EmbeddingLookup()(input_params, input_indices)
-        [[[10, 11], [8 ,9]], [[14, 15], [12, 13]]]
+        >>> out = nn.EmbeddingLookup(4, 2)(input_indices)
     """
-    def __init__(self, target='CPU'):
+    def __init__(self, vocab_size, embedding_size, param_init='normal',
+                 target='CPU', slice_mode='batch_slice', manual_shapes=None):
         super(EmbeddingLookup, self).__init__()
         self.target = target
         if target not in ('CPU', 'DEVICE'):
@@ -144,10 +155,60 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         self.gatherv2 = P.GatherV2()
         self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+        self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
+                                         name='embedding_table')
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        if slice_mode == EmbeddingLookUpSplitMode.FIELD_SLICE and is_auto_parallel:
+            if not manual_shapes:
+                raise ValueError("in field slice mode, manual_shapes must not be None")
+            if not isinstance(manual_shapes, tuple):
+                raise TypeError("manual_shapes type must be tuple(int), but got {}!".format(type(manual_shapes)))
+            for dim in manual_shapes:
+                Validator.check_integer('manual shape dim', dim, 0, Rel.GT, self.cls_name)
+            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
+            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
+            self.gatherv2.set_strategy(((get_group_size(), 1), (1, get_group_size())))
+            self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, get_group_size())))
+        elif slice_mode == EmbeddingLookUpSplitMode.TABLE_ROW_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((get_group_size(), 1), (1, 1)))
+            self.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
+        elif slice_mode == EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((1, get_group_size()), (1, 1)))
+            self.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
+        elif slice_mode == EmbeddingLookUpSplitMode.BATCH_SLICE and is_auto_parallel:
+            self.gatherv2.set_strategy(((1, 1), (get_group_size(), 1)))
+            self.embeddinglookup.set_strategy(((1, 1), (get_group_size(), 1)))
+        else:
+            if is_auto_parallel:
+                raise ValueError("slice_mode should be one of the modes in nn.EmbeddingLookUpSplitMode, but got "
+                                 + str(slice_mode))

-    def construct(self, params, indices):
+    def construct(self, indices):
         if self.target == "CPU":
-            out = self.embeddinglookup(params, indices, 0)
+            out = self.embeddinglookup(self.embedding_table, indices, 0)
         else:
-            out = self.gatherv2(params, indices, 0)
+            out = self.gatherv2(self.embedding_table, indices, 0)
         return out
+
+
+class EmbeddingLookUpSplitMode:
+    """
+    EmbeddingLookUp slice options in auto parallel and semi auto parallel mode.
+
+    There are four kinds of slice options: "BATCH_SLICE", "FIELD_SLICE",
+    "TABLE_ROW_SLICE" and "TABLE_COLUMN_SLICE". Default: "BATCH_SLICE".
+
+    - BATCH_SLICE: Slicing the batch dimension of indices.
+    - FIELD_SLICE: Slicing the field dimension of indices.
+    - TABLE_ROW_SLICE: Slicing the rows of the table.
+    - TABLE_COLUMN_SLICE: Slicing the columns of the table.
+
+    MODE_LIST: The list of all supported slice modes.
+    """
+
+    BATCH_SLICE = "batch_slice"
+    FIELD_SLICE = "field_slice"
+    TABLE_ROW_SLICE = "table_row_slice"
+    TABLE_COLUMN_SLICE = "table_column_slice"
+    MODE_LIST = [BATCH_SLICE, FIELD_SLICE, TABLE_ROW_SLICE, TABLE_COLUMN_SLICE]
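
A minimal usage sketch of the reworked interface, assuming the module is importable as mindspore.nn.EmbeddingLookup as in the docstring example above; the shapes and values are illustrative only.

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor

    # The embedding table is now created internally from vocab_size/embedding_size,
    # so construct() only takes the indices.
    lookup = nn.EmbeddingLookup(vocab_size=4, embedding_size=2, param_init='normal', target='CPU')
    indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
    out = lookup(indices)   # shape (2, 2, 2): one embedding vector per index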
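The field-slice bookkeeping described in the new docstring text (part i owns vocab[i] rows and its feature ids are shifted by offset[i] so they start from 0) can be illustrated with a small NumPy sketch; the split below is hypothetical and independent of MindSpore.

    import numpy as np

    # Hypothetical split of a 10-row vocabulary into two field parts.
    vocab = [6, 4]     # rows owned by part 0 and part 1
    offset = [0, 6]    # feature id offset of each part

    feature_ids = np.array([2, 7, 9])
    # Which part does each id fall into, and what is its local row inside that part?
    part = np.searchsorted(np.cumsum(vocab), feature_ids, side='right')
    local_ids = feature_ids - np.take(offset, part)
    print(part, local_ids)   # [0 1 1] [2 1 3]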