|
|
|
@ -450,7 +450,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|
|
|
|
>>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
|
|
|
|
|
>>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
|
|
|
|
|
>>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
|
|
|
|
|
>>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM')
|
|
|
|
|
>>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM', target='DEVICE')
|
|
|
|
|
>>> out = net(input_indices, input_values, field_ids)
|
|
|
|
|
>>> print(out.shape)
|
|
|
|
|
(2, 2, 2)
|
|
|
|
|