@@ -16,16 +16,21 @@
 import mindspore.common.dtype as mstype
+from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
+from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
+from .basic import ClipByNorm
 from ..cell import Cell
 
 __all__ = ['Embedding', 'EmbeddingLookup']
 
 
 class Embedding(Cell):
     r"""
     A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -45,7 +50,8 @@ class Embedding(Cell):
             Refer to class `initializer` for the values of string when a string
             is specified. Default: 'normal'.
         dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
+        padding_idx (int, None): If padding_idx is given, the output embedding vector at this index
+            will be initialized to zero. Default: None (the feature is disabled).
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -63,16 +69,24 @@ class Embedding(Cell):
         >>> output.shape
         (8, 128, 768)
     """
-    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
+    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
+                 dtype=mstype.float32, padding_idx=None):
         super(Embedding, self).__init__()
+        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
+        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
+        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
         validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
-        self.vocab_size = vocab_size
-        self.embedding_size = embedding_size
-        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
         self.use_one_hot = use_one_hot
-        self.embedding_table = Parameter(initializer(embedding_table, [vocab_size, embedding_size]),
-                                         name='embedding_table')
         self.dtype = dtype
+        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
+        self.padding_idx = padding_idx
+        if padding_idx is not None:
+            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
+                                                         "padding_idx", self.cls_name)
+            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
+            self.init_tensor[self.padding_idx] = 0
+            self.init_tensor = Tensor(self.init_tensor)
+        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
         self.expand = P.ExpandDims()
         self.reshape_flat = P.Reshape()
         self.shp_flat = (-1,)
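To see what the new padding_idx argument does end to end, here is a minimal usage sketch (it assumes this patch is applied and a working MindSpore install; the vocab/shape values are illustrative):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    # Row 0 of the embedding table is zeroed at construction time,
    # so index 0 always maps to an all-zero vector.
    net = nn.Embedding(vocab_size=8, embedding_size=4, padding_idx=0)
    ids = Tensor(np.array([[0, 2, 5]]), mstype.int32)
    out = net(ids)
    print(out.shape)            # (1, 3, 4)
    print(out.asnumpy()[0, 0])  # [0. 0. 0. 0.] -- the padded position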
@@ -99,16 +113,17 @@ class Embedding(Cell):
         return output
 
     def extend_repr(self):
-        s = 'vocab_size={}, embedding_size={},' \
-            'use_one_hot={}, ' \
-            'embedding_table={}, dtype={}'.format(
-                self.vocab_size,
-                self.embedding_size,
-                self.use_one_hot,
-                self.embedding_table,
-                self.dtype)
+        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
+            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
         return s
 
 
+@constexpr
+def _make_axis_range(start, end):
+    axis = tuple(range(start, end))
+    return axis
+
+
 class EmbeddingLookup(Cell):
     r"""
     Returns a slice of input tensor based on the specified indices.
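A note on the _make_axis_range helper added above: it is a constexpr, so it runs at graph-compile time and the axis tuple handed to ClipByNorm is a Python constant rather than a runtime value. A plain-Python sketch of what it computes (the function name here is a stand-in and the rank values are illustrative):

    def make_axis_range(start, end):   # mirrors _make_axis_range in the patch
        return tuple(range(start, end))

    # For rank-2 indices and a rank-3 lookup result, only the trailing
    # embedding axis is selected, so each looked-up vector is clipped
    # independently of the others.
    assert make_axis_range(2, 3) == (2,)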
@@ -120,8 +135,7 @@ class EmbeddingLookup(Cell):
     When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
     specified 'axis = 0' to lookup table.
     In field slice mode, the manual_shapes must be given. It is a tuple, where
-    the element is vocab[i], vocab[i] is the row numbers for i-th
-    part.
+    each element vocab[i] gives the number of rows in the i-th part.
 
     Args:
         vocab_size (int): Size of the dictionary of embeddings.
@@ -132,6 +146,8 @@ class EmbeddingLookup(Cell):
         slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
             nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
         manual_shapes (tuple): The accompaniment array in field slice mode.
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
+            or None. Default: None.
 
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
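For the field slice path, a construction sketch (the slicing itself only takes effect under semi_auto_parallel/auto_parallel, and FIELD_SLICE is assumed to be the class constant for 'field_slice'; the row counts are illustrative):

    import mindspore.nn as nn

    # manual_shapes lists the row count of each vocabulary part: here a
    # 12-row table is described as three parts of 5, 4 and 3 rows.
    lookup = nn.EmbeddingLookup(vocab_size=12, embedding_size=8,
                                slice_mode=nn.EmbeddingLookup.FIELD_SLICE,
                                manual_shapes=(5, 4, 3))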
@@ -152,7 +168,7 @@ class EmbeddingLookup(Cell):
     TABLE_COLUMN_SLICE = "table_column_slice"
 
     def __init__(self, vocab_size, embedding_size, param_init='normal',
-                 target='CPU', slice_mode='batch_slice', manual_shapes=None):
+                 target='CPU', slice_mode='batch_slice', manual_shapes=None, max_norm=None):
         super(EmbeddingLookup, self).__init__()
         self.target = target
         if target not in ('CPU', 'DEVICE'):
@@ -160,7 +176,9 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         self.gatherv2 = P.GatherV2()
         self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
-        self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
+        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
+        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                          name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
@@ -188,10 +206,18 @@ class EmbeddingLookup(Cell):
             if is_auto_parallel:
                 raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
                                  + str(slice_mode))
+        self.max_norm = max_norm
+        if self.max_norm is not None:
+            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
+            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
 
     def construct(self, indices):
         if self.target == "CPU":
             out = self.embeddinglookup(self.embedding_table, indices, 0)
         else:
             out = self.gatherv2(self.embedding_table, indices, 0)
+        if self.max_norm is not None:
+            axis = _make_axis_range(F.rank(indices), F.rank(out))
+            clip_by_norm = ClipByNorm(axis)
+            out = clip_by_norm(out, self.max_norm)
         return out
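Finally, a minimal sketch of the new max_norm behavior in construct: after the gather, every looked-up vector is clipped so that its L2 norm is at most max_norm (again assuming this patch is applied; the numbers are illustrative):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    lookup = nn.EmbeddingLookup(vocab_size=10, embedding_size=4, max_norm=1.0)
    ids = Tensor(np.array([1, 3, 7]), mstype.int32)
    out = lookup(ids)

    # ClipByNorm runs over the axes appended by the lookup, so each row
    # of `out` ends up with L2 norm <= max_norm.
    norms = np.linalg.norm(out.asnumpy(), axis=-1)
    assert (norms <= 1.0 + 1e-5).all()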