pre padding in dygraph (#30163)

Change-Id: Ia5279b0cbb6a5b3970aff66e9510e0d85efa70ce
revert-31562-mean
tangwei12 4 years ago committed by GitHub
parent 198fbdfb60
commit 4763e6bc4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,12 +26,10 @@ paddle.disable_static()
class EmbeddingDygraph(unittest.TestCase):
def test_1(self):
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
paddle.disable_static(paddle.CPUPlace())
x = paddle.to_tensor(x_data, stop_gradient=False)
y = paddle.to_tensor(y_data, stop_gradient=False)
embedding = paddle.nn.Embedding(10, 3, sparse=True)
embedding = paddle.nn.Embedding(10, 3, sparse=True, padding_idx=9)
w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
embedding.weight.set_value(w0)

@ -16,6 +16,7 @@
import paddle
from ...fluid.dygraph import Flatten #DEFINE_ALIAS
from ...fluid.dygraph import layers
from ...fluid.framework import in_dygraph_mode
from .. import functional as F
from ...fluid.framework import _dygraph_tracer
@ -1352,6 +1353,9 @@ class Embedding(layers.Layer):
dtype=self._dtype,
is_bias=False)
if in_dygraph_mode() and padding_idx != -1:
self.weight[padding_idx] = 0.0
def forward(self, x):
return F.embedding(
x,

Loading…
Cancel
Save