# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""embedding"""
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from .basic import ClipByNorm
from .math import Range
from ..cell import Cell

__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']


class Embedding(Cell):
r"""
|
|
A simple lookup table that stores embeddings of a fixed dictionary and size.
|
|
|
|
This module is often used to store word embeddings and retrieve them using
|
|
indices. The input to the module is a list of indices, and the output is
|
|
the corresponding word embeddings.
|
|
|
|
Note:
|
|
When 'use_one_hot' is set to True, the type of the input must be mindspore.int32.
|
|
|
|
Args:
|
|
vocab_size (int): Size of the dictionary of embeddings.
|
|
embedding_size (int): The size of each embedding vector.
|
|
use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
|
|
embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
|
|
Refer to class `initializer` for the values of string when a string
|
|
is specified. Default: 'normal'.
|
|
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
|
padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
|
|
will be initialized to zero. Default: None. The feature is inactivated.
|
|
Inputs:
|
|
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
|
|
the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
|
|
be zero.
|
|
|
|
Outputs:
|
|
Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
|
|
|
|
Supported Platforms:
|
|
``Ascend`` ``GPU``
|
|
|
|
Examples:
|
|
>>> net = nn.Embedding(20000, 768, True)
|
|
>>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
|
|
>>>
|
|
>>> # Maps the input word IDs to word embedding.
|
|
>>> output = net(input_data)
|
|
>>> result = output.shape
|
|
>>> print(result)
|
|
(8, 128, 768)
|
|
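        >>> # A minimal additional sketch (illustrative, not from the original docstring):
        >>> # with padding_idx set, the corresponding row of the embedding table is
        >>> # zero-initialized, so that index maps to a zero vector.
        >>> net = nn.Embedding(20000, 768, padding_idx=0)
        >>> output = net(input_data)
        >>> print(output.shape)
        (8, 128, 768)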
"""
|
|
|
|
    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
                 dtype=mstype.float32, padding_idx=None):
        super(Embedding, self).__init__()
        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.use_one_hot = use_one_hot
        self.dtype = dtype
        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
        self.padding_idx = padding_idx
        if padding_idx is not None:
            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                         "padding_idx", self.cls_name)
            # Zero out the row of the embedding table that corresponds to padding_idx.
            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
            self.init_tensor[self.padding_idx] = 0
        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
        self.expand = P.ExpandDims()
        self.reshape_flat = P.Reshape()
        self.shp_flat = (-1,)
        self.gather = P.GatherV2()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.get_shp = P.Shape()

    def construct(self, ids):
        extended_ids = self.expand(ids, -1)
        out_shape = self.get_shp(ids) + (self.embedding_size,)
        flat_ids = self.reshape_flat(extended_ids, self.shp_flat)

        if self.use_one_hot:
            # One-hot encode the ids and multiply by the table; equivalent to a gather.
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)

        output = self.reshape(output_for_reshape, out_shape)
        return output

    def extend_repr(self):
        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype,
            self.padding_idx)
        return s


@constexpr
def _make_axis_range(start, end):
    axis = tuple(range(start, end))
    return axis

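
# Illustrative note (not part of the original source): for a rank-3 `out` produced from
# rank-2 `indices`, _make_axis_range(F.rank(indices), F.rank(out)) evaluates to (2,),
# i.e. ClipByNorm in EmbeddingLookup.construct clips along the trailing embedding axis
# when `max_norm` is given.
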
class EmbeddingLookup(Cell):
    r"""
    Returns a slice of the input tensor based on the specified indices.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which
        specifies 'offset = 0' for the lookup table.
        When 'target' is set to 'DEVICE', this module will use P.GatherV2(), which
        specifies 'axis = 0' for the lookup table.
        In field slice mode, manual_shapes must be given. It is a tuple whose i-th
        element, vocab[i], is the number of rows of the i-th part of the table.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        param_init (str): The initializer for the embedding table. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be one of the
            constants defined in nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        manual_shapes (tuple): The row numbers of each part of the table in field slice mode. Default: None.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None.
        sparse (bool): Whether to use sparse mode. When 'target' is set to 'CPU', 'sparse' must be True.
            Default: True.

    Inputs:
        - **input_indices** (Tensor) - The shape of the tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of the range of the
          embedding table, and the out-of-range part will be filled with 0 in the output. input_indices must
          be a 2d tensor when this interface runs in semi_auto_parallel/auto_parallel mode.

    Outputs:
        Tensor, the shape of the tensor is :math:`(z_1, z_2, ..., z_N)`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
        >>> result = nn.EmbeddingLookup(4, 2)(input_indices)
        >>> print(result)
        [[[ 0.00856617  0.01039034]
          [ 0.00196276 -0.00094072]]
         [[ 0.01279703  0.00078912]
          [ 0.00084863 -0.00742412]]]
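        >>> # A minimal additional sketch (illustrative, not from the original docstring):
        >>> # the looked-up values are random at initialization, so only the shape is checked here.
        >>> result = nn.EmbeddingLookup(4, 2, max_norm=1.0)(input_indices)
        >>> print(result.shape)
        (2, 2, 2)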
"""
|
|
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', slice_mode='batch_slice', manual_shapes=None,
                 max_norm=None, sparse=True):
        super(EmbeddingLookup, self).__init__()
        self.target = target
        if target not in ('CPU', 'DEVICE'):
            raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
                             + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
        if not sparse and target == 'CPU':
            raise ValueError('When target is CPU, embedding_lookup must be sparse.')
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.GatherV2()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                         name='embedding_table')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.forward_unique = False
        self.gather_revert = P.GatherV2()
        self.unique = P.Unique().shard(((1,),))
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        indices_shape_size = 2
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError("in field slice mode, the manual_shapes should not be none")
            if not isinstance(manual_shapes, tuple):
                raise TypeError("manual_shapes type must be tuple(int), cannot be {}!".format(type(manual_shapes)))
            for dim in manual_shapes:
                validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            if target == 'DEVICE':
                # On device, deduplicate indices first, so only 1-D indices reach the sharded gather.
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (1,)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1,)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1] * (indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but got "
                                 + str(slice_mode))
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)

    def construct(self, indices):
        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            if self.forward_unique:
                # Deduplicate the flattened indices, gather the unique rows, then map the
                # rows back to the original positions before restoring the input shape.
                shp = self.shape(indices) + (self.embedding_size,)
                indices_flatten = self.reshape(indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, indices, 0)
        if self.max_norm is not None:
            # Clip each looked-up vector along the embedding axes to max_norm.
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out


class MultiFieldEmbeddingLookup(EmbeddingLookup):
    r"""
    Returns a slice of the input tensor based on the specified indices and field ids. This operation
    supports looking up embeddings using multi-hot and one-hot fields simultaneously.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which
        specifies 'offset = 0' for the lookup table.
        When 'target' is set to 'DEVICE', this module will use P.GatherV2(), which
        specifies 'axis = 0' for the lookup table.
        Vectors with the same field_ids will be combined by the `operator`, which can be `SUM`, `MAX` or
        `MEAN`. Ensure that the input_values of the padded ids are zero, so that they can be ignored. The
        final output will be zeros if the sum of absolute weights of the field is zero. This class only
        supports ['table_row_slice', 'batch_slice', 'table_column_slice'].

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        field_size (int): The field size of the final outputs.
        param_init (str): The initializer for the embedding table. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be one of the
            constants defined in nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        feature_num_list (tuple): The accompaniment array in field slice mode. Default: None.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None.
        sparse (bool): Whether to use sparse mode. When 'target' is set to 'CPU', 'sparse' must be True.
            Default: True.
        operator (str): The pooling method for the features in one field. Supported values are 'SUM',
            'MEAN' and 'MAX'. Default: 'SUM'.

    Inputs:
        - **input_indices** (Tensor) - The shape of the tensor is :math:`(batch_size, seq_length)`.
          Specifies the indices of elements of the original Tensor. input_indices must be a 2d tensor in
          this interface. Type is Int16, Int32 or Int64.
        - **input_values** (Tensor) - The shape of the tensor is :math:`(batch_size, seq_length)`.
          Specifies the weights of elements of input_indices. The looked-up vectors are multiplied by
          input_values. Type is Float32.
        - **field_ids** (Tensor) - The shape of the tensor is :math:`(batch_size, seq_length)`.
          Specifies the field id of elements of input_indices. Type is Int16 or Int32.

    Outputs:
        Tensor, the shape of the tensor is :math:`(batch_size, field_size, embedding_size)`. Type is Float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
        >>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
        >>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
        >>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM')
        >>> out = net(input_indices, input_values, field_ids)
        >>> print(out)
        [[[-0.00478983 -0.00772568]
          [-0.00968955 -0.00064902]]
         [[-0.01251151 -0.01251151]
          [-0.00196387 -0.00196387]]]
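        >>> # A minimal additional sketch (illustrative, not from the original docstring):
        >>> # with random initialization only the output shape is deterministic.
        >>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='MEAN')
        >>> out = net(input_indices, input_values, field_ids)
        >>> print(out.shape)
        (2, 2, 2)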
"""
|
|
    OPERATOR_SUM = 'SUM'
    OPERATOR_MEAN = 'MEAN'
    OPERATOR_MAX = 'MAX'

    def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
        super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
                                                        slice_mode, feature_num_list, max_norm, sparse)
        self.field_size = validator.check_value_type('field_size', field_size, [int], self.cls_name)
        self.operator = operator

        self.mul = P.Mul()
        self.inf_mask_mul = P.Mul()
        self.bias_add = P.TensorAdd()
        self.inf_add = P.TensorAdd()
        self.merge_op = None
        self.count_op = P.UnsortedSegmentSum()
        self.abs = P.Abs()
        self.equal = P.Equal()
        self.add = P.TensorAdd()
        self.cast = P.Cast()
        self.div_no_nan = P.DivNoNan()
        self.expand = P.ExpandDims()
        self.max_mask_mul = P.Mul()
        self.max_no_equal = P.NotEqual()

        if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
            self.merge_op = P.UnsortedSegmentSum()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            self.merge_op = P.UnsortedSegmentMax()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MEAN:
            self.merge_op = P.UnsortedSegmentSum()
        else:
            raise ValueError("The operator supports ['SUM', 'MAX', 'MEAN'], but found: " + str(operator))

        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
            self.merge_op.shard(((get_group_size(), 1, 1), (get_group_size(), 1)))
            self.expand.shard(((get_group_size(),),))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
            self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.add.shard(((get_group_size(),), (get_group_size(),)))
            self.div_no_nan.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_mask_mul.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((get_group_size(), 1, 1), ()))
                self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
                self.merge_op.shard(((get_group_size(), 1), (get_group_size(),)))
                self.count_op.shard(((get_group_size(),), (get_group_size(),)))
                self.inf_add.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
            self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
            self.count_op.shard(((1, 1), (1, 1)))
            self.add.shard(((1,), (1,)))
            self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
            self.expand.shard(((1,),))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((1, 1, 1), ()))
                self.inf_mask_mul.shard(((1, 1, 1), ()))
                self.merge_op.shard(((1, get_group_size()), (1,)))
                self.count_op.shard(((1,), (1,)))
                self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
        else:
            if is_auto_parallel:
                raise ValueError("slice_mode should be ['table_row_slice', 'batch_slice', "
                                 "'table_column_slice'], but got " + str(slice_mode))

        # Min value for fp32
        self.negative_inf_value = -3.402823466E+38

    def construct(self, input_indices, input_values, field_ids):

        batch_size = self.shape(input_indices)[0]
        num_segments = batch_size * self.field_size

        # Offset the per-sample field ids so that every (sample, field) pair maps to a
        # unique segment id in [0, num_segments) for the unsorted segment reduction.
        bias = Range(0, num_segments, self.field_size)()
        bias = self.reshape(bias, (self.field_size, -1))
        field_ids = self.bias_add(field_ids, bias)

        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, input_indices, 0)
        else:
            if self.forward_unique:
                shp = self.shape(input_indices) + (self.embedding_size,)
                indices_flatten = self.reshape(input_indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, input_indices, 0)
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(input_indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)

        weights = self.reshape(input_values, (batch_size, self.shape(input_indices)[1], 1))
        embedding = self.mul(weights, out)

        if self.operator == 'MAX':
            # Fill the padded positions with -inf so that they do not influence the maximum.
            negative_inf_mask = self.cast(self.equal(weights, 0), mstype.float32)
            inf_mask = self.inf_mask_mul(negative_inf_mask, self.negative_inf_value)
            embedding = self.inf_add(embedding, inf_mask)
        embedding = self.reshape(embedding, (-1, self.embedding_size))
        field_ids = self.reshape(field_ids, (-1,))

        merged_vectors = self.merge_op(embedding, field_ids, num_segments)

        if self.operator == 'MAX':
            # Zero out segments whose total absolute weight is zero (fully padded fields).
            value_count = self.count_op(self.abs(self.reshape(input_values, (-1,))), field_ids, num_segments)
            value_zeros = self.cast(self.max_no_equal(value_count, 0.0), mstype.float32)
            count = self.expand(value_zeros, -1)
            merged_vectors = self.max_mask_mul(merged_vectors, count)

        if self.operator == 'MEAN':
            # Normalize each segment sum by the total absolute weight of the segment.
            value_count = self.count_op(self.abs(input_values), field_ids, num_segments)
            value_count = self.expand(value_count, -1)
            merged_vectors = self.div_no_nan(merged_vectors, value_count)

        merged_vectors = self.reshape(merged_vectors, (batch_size, self.field_size, -1))
        return merged_vectors