|
|
|
@ -44,10 +44,11 @@ class Embedding(Cell):
|
|
|
|
|
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(\text{vocab_size})`.
|
|
|
|
|
|
|
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of
|
|
|
|
|
the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero
|
|
|
|
|
if larger than vocab_size.
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor of shape :math:`(\text{vocab_size}, \text{embedding_size})`.
|
|
|
|
|
Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> net = nn.Embedding(20000, 768, True)
|
|
|
|
@ -61,6 +62,7 @@ class Embedding(Cell):
|
|
|
|
|
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
|
|
|
|
|
super(Embedding, self).__init__()
|
|
|
|
|
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
|
|
|
|
validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
|
|
|
|
|
self.vocab_size = vocab_size
|
|
|
|
|
self.embedding_size = embedding_size
|
|
|
|
|
self.use_one_hot = use_one_hot
|
|
|
|
|