|
|
|
@ -21,7 +21,7 @@ from mindspore.ops import functional as F
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore.common.initializer import initializer
|
|
|
|
|
from mindspore.communication.management import get_group_size
|
|
|
|
|
from mindspore.context import ParallelMode
|
|
|
|
|
from mindspore.context import ParallelMode, get_context
|
|
|
|
|
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
|
|
|
|
|
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
|
|
|
|
|
from mindspore._checkparam import Rel
|
|
|
|
@ -278,7 +278,7 @@ class EmbeddingLookup(Cell):
|
|
|
|
|
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
|
|
|
|
|
+ str(slice_mode))
|
|
|
|
|
if self.cache_enable and not enable_ps:
|
|
|
|
|
if is_auto_parallel:
|
|
|
|
|
if parallel_mode != ParallelMode.STAND_ALONE:
|
|
|
|
|
raise ValueError("parallel mode haven't supported cache enable yet.")
|
|
|
|
|
self._set_cache_enable()
|
|
|
|
|
self.embedding_table.unique = self.forward_unique
|
|
|
|
@ -288,15 +288,14 @@ class EmbeddingLookup(Cell):
|
|
|
|
|
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
|
|
|
|
|
|
|
|
|
|
def _set_cache_enable(self):
|
|
|
|
|
"""EmbeddingLookup cache check for not ps env."""
|
|
|
|
|
"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
|
|
|
|
|
if self.target != 'DEVICE':
|
|
|
|
|
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
|
|
|
|
|
"so it will be ignored.")
|
|
|
|
|
return
|
|
|
|
|
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
|
|
|
|
|
if not self.sparse:
|
|
|
|
|
logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, "
|
|
|
|
|
"so it will be ignored.")
|
|
|
|
|
return
|
|
|
|
|
raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
|
|
|
|
|
if get_context("device_target") != 'Ascend':
|
|
|
|
|
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
|
|
|
|
|
|
|
|
|
|
logger.info("EmbeddingLookup cache enable takes effect.")
|
|
|
|
|
self.forward_unique = True
|
|
|
|
|
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
|
|
|
|
|