|
|
@ -15,7 +15,7 @@
|
|
|
|
"""embedding"""
|
|
|
|
"""embedding"""
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
from mindspore import log as logger
|
|
|
|
from mindspore import log as logger
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
from mindspore.common.tensor import Tensor, MetaTensor
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
@ -101,8 +101,11 @@ class Embedding(Cell):
|
|
|
|
if padding_idx is not None:
|
|
|
|
if padding_idx is not None:
|
|
|
|
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
|
|
|
|
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
|
|
|
|
"padding_idx", self.cls_name)
|
|
|
|
"padding_idx", self.cls_name)
|
|
|
|
self.init_tensor = self.init_tensor.to_tensor().asnumpy()
|
|
|
|
if isinstance(self.init_tensor, MetaTensor):
|
|
|
|
|
|
|
|
self.init_tensor = self.init_tensor.to_tensor()
|
|
|
|
|
|
|
|
self.init_tensor = self.init_tensor.asnumpy()
|
|
|
|
self.init_tensor[self.padding_idx] = 0
|
|
|
|
self.init_tensor[self.padding_idx] = 0
|
|
|
|
|
|
|
|
self.init_tensor = Tensor(self.init_tensor)
|
|
|
|
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
|
|
|
|
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
|
|
|
|
self.expand = P.ExpandDims()
|
|
|
|
self.expand = P.ExpandDims()
|
|
|
|
self.reshape_flat = P.Reshape()
|
|
|
|
self.reshape_flat = P.Reshape()
|
|
|
|