@@ -14,7 +14,6 @@
 # ============================================================================
 """embedding"""
 import mindspore.common.dtype as mstype
-import mindspore.context as context
 from mindspore import log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
@@ -23,8 +22,8 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_parallel_mode
-from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker
+from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
+from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
 from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
 from mindspore.ops.primitive import constexpr
@@ -195,11 +194,6 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         if not sparse and target == 'CPU':
             raise ValueError('When target is CPU, embedding_lookup must be sparse.')
-        enable_ps = context.get_ps_context("enable_ps")
-        if not enable_ps and vocab_cache_size > 0:
-            logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, "
-                           "current mode is not parameter server trainning mode, so it will be ignored.")
-            vocab_cache_size = 0
         if sparse:
             self.gatherv2 = P.SparseGatherV2()
         else:
@@ -207,22 +201,14 @@ class EmbeddingLookup(Cell):
             self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
         self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
         self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
+        self._process_vocab_cache(slice_mode)
         self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
-        parallel_mode = _get_parallel_mode()
-        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-        self.cache_enable = self.vocab_cache_size > 0
-        if self.cache_enable:
-            if is_auto_parallel:
-                self.vocab_cache_size = self.vocab_cache_size * get_group_size()
-            self.vocab_size = self.vocab_cache_size
-
         self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                          name='embedding_table')
         if self.cache_enable:
-            self.embedding_table.cache_enable = True
-            _set_cache_enable(True)
-            if _is_role_worker():
-                _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
+            self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size)
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.forward_unique = False
         self.gather_revert = P.GatherV2()
         self.unique = P.Unique().shard(((1,),))
@@ -241,7 +227,8 @@ class EmbeddingLookup(Cell):
             self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
             self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
         elif slice_mode == "table_row_slice" and is_auto_parallel:
-            if target == 'DEVICE' and not self.cache_enable:
+            full_batch = _get_full_batch()
+            if target == 'DEVICE' and not full_batch:
                 indices_shape_size = 1
                 self.gather_revert.shard(((1, 1), (get_group_size(),)))
                 self.forward_unique = True
@@ -272,6 +259,39 @@ class EmbeddingLookup(Cell):
             self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
             self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
 
+    def _process_vocab_cache(self, slice_mode):
+        """PS embeddingLookup cache check and process."""
+        self.cache_enable = False
+        if self.vocab_cache_size > 0:
+            if self.target == 'CPU':
+                logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
+                               "current target is CPU, so it will be ignored.")
+                return
+            enable_ps = _get_ps_context("enable_ps")
+            if not enable_ps:
+                logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
+                               "mode, current mode is not parameter server trainning mode, so it will be ignored.")
+                return
+            parallel_mode = _get_parallel_mode()
+            is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+            if is_auto_parallel:
+                device_num = get_group_size()
+                full_batch = _get_full_batch()
+                if device_num > 1 and not (full_batch and slice_mode == TABLE_ROW_SLICE):
+                    raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
+                                     "in 'full_batch' and 'table_row_slice' parallel strategy.")
+                self.vocab_cache_size = self.vocab_cache_size * device_num
+            self.cache_enable = True
+            self.vocab_size = self.vocab_cache_size
+
+    def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size):
+        """PS embeddingLookup cache enable set."""
+        self.embedding_table.cache_enable = True
+        self.embedding_table.is_param_ps = True
+        _set_cache_enable(True)
+        if _is_role_worker():
+            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
+
     def construct(self, indices):
         if self.target == "CPU":
             out = self.embeddinglookup(self.embedding_table, indices, 0)
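
Usage note (not part of the patch): a minimal sketch of how the vocab cache path refactored above is meant to be driven, assuming the MindSpore 1.x Python API. The keyword arguments follow the public `nn.EmbeddingLookup` signature; the concrete sizes are illustrative, and the parameter-server role/launch environment (MS_ROLE, scheduler/server processes) is omitted.

# Sketch only: exercises the vocab_cache_size path that
# _process_vocab_cache() validates after this patch.
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

# The cache is honored only in parameter server training mode with a
# 'DEVICE' target; otherwise _process_vocab_cache() warns and ignores it.
context.set_ps_context(enable_ps=True)

# Under multi-device auto parallel, the patch additionally requires
# full_batch together with the 'table_row_slice' strategy.
embedding = nn.EmbeddingLookup(vocab_size=1000000,
                               embedding_size=80,
                               target='DEVICE',
                               slice_mode='table_row_slice',
                               sparse=True,
                               vocab_cache_size=100000)

ids = Tensor(np.array([[1, 3], [5, 7]], dtype=np.int32))
out = embedding(ids)  # expected shape: (2, 2, 80)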