!11462 add pipe for cache embedding

From: @fangzehua
pull/11462/MERGE
Committed by mindspore-ci-bot via Gitee
commit e02b6852cb

@@ -22,7 +22,7 @@
 namespace mindspore {
 namespace parallel {
 // Automatically adding control depend based on effect order and side effect analysis.
-void AddCacheEmbedding(const FuncGraphPtr &graph);
+void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe = false);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_

@@ -21,7 +21,7 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_context
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
 from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
 from mindspore._checkparam import Rel
@@ -278,7 +278,7 @@ class EmbeddingLookup(Cell):
                 raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
                                  + str(slice_mode))
         if self.cache_enable and not enable_ps:
-            if is_auto_parallel:
+            if parallel_mode != ParallelMode.STAND_ALONE:
                 raise ValueError("parallel mode haven't supported cache enable yet.")
             self._set_cache_enable()
         self.embedding_table.unique = self.forward_unique
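
With this hunk, the device cache path is accepted only in stand-alone mode; any data/auto/hybrid parallel mode now raises instead of being gated on is_auto_parallel alone. A minimal usage sketch of the configuration this check admits (the vocab and cache sizes are illustrative, not taken from this commit):

import mindspore.nn as nn
from mindspore import context

# Stand-alone Ascend setup: the only parallel mode the cache check accepts here.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# vocab_cache_size > 0 turns on the device-side cache; _set_cache_enable() then
# verifies target='DEVICE' and sparse=True (see the next hunk).
embedding = nn.EmbeddingLookup(vocab_size=8000000,
                               embedding_size=80,
                               target='DEVICE',
                               sparse=True,
                               vocab_cache_size=100000)
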
@@ -288,15 +288,14 @@ class EmbeddingLookup(Cell):
             self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
 
     def _set_cache_enable(self):
-        """EmbeddingLookup cache check for not ps env."""
+        """EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
         if self.target != 'DEVICE':
-            logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
-                           "so it will be ignored.")
-            return
+            raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
         if not self.sparse:
-            logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, "
-                           "so it will be ignored.")
-            return
+            raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
+        if get_context("device_target") != 'Ascend':
+            raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
         logger.info("EmbeddingLookup cache enable takes effect.")
         self.forward_unique = True
         self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
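
The shift from logger.warning to ValueError makes a misconfigured cache fail fast instead of being silently dropped. A hedged sketch of the new behavior (the message text is quoted from the diff above; the call site itself is illustrative):

import mindspore.nn as nn

try:
    # target='CPU' with a nonzero vocab_cache_size: before this commit the
    # cache request was warned about and ignored; now construction raises.
    nn.EmbeddingLookup(vocab_size=10000, embedding_size=16,
                       target='CPU', sparse=True, vocab_cache_size=1000)
except ValueError as err:
    print(err)  # The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.
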
