|
|
|
@ -59,7 +59,7 @@ class Embedding(Cell):
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
vocab_size (int): Size of the dictionary of embeddings.
|
|
|
|
|
embedding_size (int): The size of each embedding vector.
|
|
|
|
|
embedding_size (Union[int, tuple(int), list(int)]): The size of each embedding vector.
|
|
|
|
|
use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
|
|
|
|
|
embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
|
|
|
|
|
Refer to class `initializer` for the values of string when a string
|
|
|
|
@ -98,12 +98,21 @@ class Embedding(Cell):
|
|
|
|
|
dtype=mstype.float32, padding_idx=None):
|
|
|
|
|
super(Embedding, self).__init__()
|
|
|
|
|
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
|
|
|
|
|
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
|
|
|
|
|
self.embedding_size = validator.check_value_type('embedding_size', embedding_size,
|
|
|
|
|
[int, tuple, list], self.cls_name)
|
|
|
|
|
validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
|
|
|
|
|
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
|
|
|
|
self.use_one_hot = use_one_hot
|
|
|
|
|
self.dtype = dtype
|
|
|
|
|
self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
|
|
|
|
|
if isinstance(self.embedding_size, int):
|
|
|
|
|
self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
|
|
|
|
|
self.embedding_out = (self.embedding_size,)
|
|
|
|
|
else:
|
|
|
|
|
if len(self.embedding_size) != 2:
|
|
|
|
|
raise ValueError("embedding_size should be a int or a tuple of two ints")
|
|
|
|
|
self.init_tensor = initializer(embedding_table, [vocab_size, self.embedding_size[0],
|
|
|
|
|
self.embedding_size[1]])
|
|
|
|
|
self.embedding_out = (self.embedding_size[0], self.embedding_size[1],)
|
|
|
|
|
self.padding_idx = padding_idx
|
|
|
|
|
if padding_idx is not None:
|
|
|
|
|
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
|
|
|
|
@ -127,7 +136,7 @@ class Embedding(Cell):
|
|
|
|
|
|
|
|
|
|
def construct(self, ids):
|
|
|
|
|
extended_ids = self.expand(ids, -1)
|
|
|
|
|
out_shape = self.get_shp(ids) + (self.embedding_size,)
|
|
|
|
|
out_shape = self.get_shp(ids) + self.embedding_out
|
|
|
|
|
flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
|
|
|
|
|
|
|
|
|
|
if self.use_one_hot:
|
|
|
|
|