|
|
@ -189,6 +189,7 @@ class EmbeddingLookup(Cell):
|
|
|
|
target='CPU', slice_mode='batch_slice', manual_shapes=None,
|
|
|
|
target='CPU', slice_mode='batch_slice', manual_shapes=None,
|
|
|
|
max_norm=None, sparse=True, vocab_cache_size=0):
|
|
|
|
max_norm=None, sparse=True, vocab_cache_size=0):
|
|
|
|
super(EmbeddingLookup, self).__init__()
|
|
|
|
super(EmbeddingLookup, self).__init__()
|
|
|
|
|
|
|
|
validator.check_value_type('sparse', sparse, [bool], self.cls_name)
|
|
|
|
self.target = target
|
|
|
|
self.target = target
|
|
|
|
if target not in ('CPU', 'DEVICE'):
|
|
|
|
if target not in ('CPU', 'DEVICE'):
|
|
|
|
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
|
|
|
|
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
|
|
|
@ -200,9 +201,9 @@ class EmbeddingLookup(Cell):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.gatherv2 = P.GatherV2()
|
|
|
|
self.gatherv2 = P.GatherV2()
|
|
|
|
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
|
|
|
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
|
|
|
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
|
|
|
|
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
|
|
|
|
self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name)
|
|
|
|
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
|
|
|
|
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
|
|
|
|
self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
|
|
|
|
parallel_mode = _get_parallel_mode()
|
|
|
|
parallel_mode = _get_parallel_mode()
|
|
|
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
|
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
|
|
self.cache_enable = self.vocab_cache_size > 0
|
|
|
|
self.cache_enable = self.vocab_cache_size > 0
|
|
|
@ -355,7 +356,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|
|
|
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
|
|
|
|
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
|
|
|
|
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
|
|
|
|
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
|
|
|
|
slice_mode, feature_num_list, max_norm, sparse)
|
|
|
|
slice_mode, feature_num_list, max_norm, sparse)
|
|
|
|
self.field_size = validator.check_value_type('field_size', field_size, [int], self.cls_name)
|
|
|
|
self.field_size = validator.check_positive_int(field_size, 'field_size')
|
|
|
|
self.operator = operator
|
|
|
|
self.operator = operator
|
|
|
|
|
|
|
|
|
|
|
|
self.mul = P.Mul()
|
|
|
|
self.mul = P.Mul()
|
|
|
@ -429,7 +430,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|
|
|
batch_size = self.shape(input_indices)[0]
|
|
|
|
batch_size = self.shape(input_indices)[0]
|
|
|
|
num_segments = batch_size * self.field_size
|
|
|
|
num_segments = batch_size * self.field_size
|
|
|
|
bias = Range(0, num_segments, self.field_size)()
|
|
|
|
bias = Range(0, num_segments, self.field_size)()
|
|
|
|
bias = self.reshape(bias, (self.field_size, -1))
|
|
|
|
bias = self.reshape(bias, (batch_size, -1))
|
|
|
|
field_ids = self.bias_add(field_ids, bias)
|
|
|
|
field_ids = self.bias_add(field_ids, bias)
|
|
|
|
|
|
|
|
|
|
|
|
if self.target == "CPU":
|
|
|
|
if self.target == "CPU":
|
|
|
|