|
|
|
@ -306,7 +306,8 @@ def embedding(input,
|
|
|
|
|
is_distributed=False,
|
|
|
|
|
padding_idx=None,
|
|
|
|
|
param_attr=None,
|
|
|
|
|
dtype='float32'):
|
|
|
|
|
dtype='float32',
|
|
|
|
|
remote_prefetch=False):
|
|
|
|
|
"""
|
|
|
|
|
**Embedding Layer**
|
|
|
|
|
|
|
|
|
@ -345,7 +346,7 @@ def embedding(input,
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
helper = LayerHelper('embedding', **locals())
|
|
|
|
|
remote_prefetch = is_sparse and (not is_distributed)
|
|
|
|
|
remote_prefetch = is_sparse and (not is_distributed) and remote_prefetch
|
|
|
|
|
if remote_prefetch:
|
|
|
|
|
assert is_sparse is True and is_distributed is False
|
|
|
|
|
w = helper.create_parameter(
|
|
|
|
|