!7651 Add embedding and embedding_lookup

Merge pull request !7651 from jiangzhenguang/add_embedding_and_enbedding_lookup
pull/7651/MERGE
Committed by mindspore-ci-bot via Gitee
commit 0d43e37f22

@@ -59,7 +59,7 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
auto prim = std::make_shared<Primitive>(kClipByNormNoDivSumOpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_select, constant_greater, constant_maximum};
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_greater, constant_select, constant_maximum};
auto fusion_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fusion_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};

@@ -260,6 +260,7 @@ def _is_float_dtype(dtype):
return True
return False
class ClipByNorm(Cell):
r"""
Clips tensor values to a maximum :math:`L_2`-norm.
@@ -271,6 +272,9 @@ class ClipByNorm(Cell):
\text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
Args:
axis (Union[None, int, tuple(int)]): Dimensions along which the L2-norm is computed.
Default: None, which computes the norm over all dimensions.
Inputs:
- **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
@@ -287,8 +291,14 @@ class ClipByNorm(Cell):
"""
def __init__(self):
def __init__(self, axis=None):
super(ClipByNorm, self).__init__()
if axis is None:
axis = ()
if isinstance(axis, tuple):
for idx, item in enumerate(axis):
Validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
self.axis = Validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.select_ = P.Select()
self.greater_ = P.Greater()
@@ -305,7 +315,7 @@ class ClipByNorm(Cell):
def construct(self, x, clip_norm):
"""add ms_function decorator for pynative mode"""
mul_x = F.square(x)
l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
cond = self.greater_(l2sum, 0)
ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
@@ -318,7 +328,9 @@ class ClipByNorm(Cell):
intermediate = x * clip_norm
max_norm = self.max_op(l2norm, clip_norm)
values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
if self.axis is None:
max_norm = self.expand_dims(max_norm, -1)
values_clip = self.cast(intermediate, mstype.float32) / max_norm
values_clip = self.reshape(values_clip, self.shape(x))
values_clip = identity(values_clip)
return values_clip
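Not part of the diff, but a minimal usage sketch of the new `axis` argument may help (input shapes and values below are made up for illustration):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

x = Tensor(np.random.rand(4, 16).astype(np.float32))
clip_norm = Tensor(np.array([1.0]).astype(np.float32))

# axis=None (default): one global L2-norm is reduced over all dimensions,
# so the whole tensor is scaled by a single factor.
out_all = nn.ClipByNorm()(x, clip_norm)
# axis=0: the squared sum is reduced along dimension 0 only (keep_dims=True
# in ReduceSum keeps the result broadcastable), so each column is clipped
# independently.
out_cols = nn.ClipByNorm(axis=0)(x, clip_norm)
print(out_all.shape, out_cols.shape)  # (4, 16) (4, 16)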

@@ -16,16 +16,21 @@
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from .basic import ClipByNorm
from ..cell import Cell
__all__ = ['Embedding', 'EmbeddingLookup']
class Embedding(Cell):
r"""
A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -45,7 +50,8 @@ class Embedding(Cell):
Refer to class `initializer` for the values of string when a string
is specified. Default: 'normal'.
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
padding_idx (int, None): If given, the embedding vector at index padding_idx is
initialized to zero. Default: None, the feature is disabled.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -63,16 +69,24 @@
>>> output.shape
(8, 128, 768)
"""
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
dtype=mstype.float32, padding_idx=None):
super(Embedding, self).__init__()
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
self.vocab_size = vocab_size
self.embedding_size = embedding_size
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.use_one_hot = use_one_hot
self.embedding_table = Parameter(initializer(embedding_table, [vocab_size, embedding_size]),
name='embedding_table')
self.dtype = dtype
self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
self.padding_idx = padding_idx
if padding_idx is not None:
self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
"padding_idx", self.cls_name)
self.init_tensor = self.init_tensor.to_tensor().asnumpy()
self.init_tensor[self.padding_idx] = 0
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
self.expand = P.ExpandDims()
self.reshape_flat = P.Reshape()
self.shp_flat = (-1,)
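For illustration (not part of the diff), a minimal sketch of the new padding_idx behaviour, reusing the configuration from the Embedding_2 test case added below:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# vocab_size, embedding_size and indices mirror the Embedding_2 test case.
net = nn.Embedding(vocab_size=10, embedding_size=3, padding_idx=2)
ids = Tensor(np.array([0, 2, 2, 7]).astype(np.int32))
out = net(ids)  # shape (4, 3)
# Row 2 of the embedding table was zero-initialized, so both lookups of
# index 2 return all-zero vectors.
print(out.shape)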
@@ -99,16 +113,17 @@ class Embedding(Cell):
return output
def extend_repr(self):
s = 'vocab_size={}, embedding_size={},' \
'use_one_hot={}, ' \
'embedding_table={}, dtype={}'.format(
self.vocab_size,
self.embedding_size,
self.use_one_hot,
self.embedding_table,
self.dtype)
s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
return s
@constexpr
def _make_axis_range(start, end):
axis = tuple(range(start, end))
return axis
class EmbeddingLookup(Cell):
r"""
Returns a slice of input tensor based on the specified indices.
@@ -120,8 +135,7 @@ class EmbeddingLookup(Cell):
When 'target' is set to 'DEVICE', this module uses P.GatherV2() with 'axis = 0'
to look up the table.
In field slice mode, the manual_shapes must be given. It is a tuple, where
the element is vocab[i], vocab[i] is the row numbers for i-th
part.
the element is vocab[i], vocab[i] is the row numbers for i-th part.
Args:
vocab_size (int): Size of the dictionary of embeddings.
@@ -132,6 +146,8 @@ class EmbeddingLookup(Cell):
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode.
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None.
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -152,7 +168,7 @@ class EmbeddingLookup(Cell):
TABLE_COLUMN_SLICE = "table_column_slice"
def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None):
target='CPU', slice_mode='batch_slice', manual_shapes=None, max_norm=None):
super(EmbeddingLookup, self).__init__()
self.target = target
if target not in ('CPU', 'DEVICE'):
@@ -160,7 +176,9 @@
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
name='embedding_table')
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
@@ -188,10 +206,18 @@ class EmbeddingLookup(Cell):
if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
+ str(slice_mode))
self.max_norm = max_norm
if self.max_norm is not None:
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
def construct(self, indices):
if self.target == "CPU":
out = self.embeddinglookup(self.embedding_table, indices, 0)
else:
out = self.gatherv2(self.embedding_table, indices, 0)
if self.max_norm is not None:
axis = _make_axis_range(F.rank(indices), F.rank(out))
clip_by_norm = ClipByNorm(axis)
out = clip_by_norm(out, self.max_norm)
return out
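For illustration (not part of the diff), a minimal sketch of the new max_norm clipping path, reusing the configuration from the EmbeddingLookup_2 test case added below:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# vocab_size, embedding_size and indices mirror the EmbeddingLookup_2 test case.
net = nn.EmbeddingLookup(vocab_size=10, embedding_size=3, max_norm=0.01)
ids = Tensor(np.array([0, 2, 2, 7]).astype(np.int32))
out = net(ids)  # shape (4, 3)
# _make_axis_range(rank(ids)=1, rank(out)=2) yields (1,), so ClipByNorm
# reduces over the embedding dimension and each looked-up vector is scaled
# down to an L2-norm of at most max_norm.
print(out.shape)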

@@ -228,6 +228,44 @@ class Moments(nn.Cell):
return mean, variance
class ClipByNorm(nn.Cell):
"""ClipByNorm net definition"""
def __init__(self, axis=None):
super(ClipByNorm, self).__init__()
self.clip_by_norm = nn.ClipByNorm(axis=axis)
def construct(self, input_x, max_norm):
norm = self.clip_by_norm(input_x, max_norm)
return norm
class Embedding(nn.Cell):
"""Embedding net definition"""
def __init__(self, vocab_size, embedding_size, padding_idx=None):
super(Embedding, self).__init__()
self.embedding = nn.Embedding(vocab_size=vocab_size, embedding_size=embedding_size,
padding_idx=padding_idx)
def construct(self, index):
res = self.embedding(index)
return res
class EmbeddingLookup(nn.Cell):
"""EmbeddingLookup net definition"""
def __init__(self, vocab_size, embedding_size, max_norm=None):
super(EmbeddingLookup, self).__init__()
self.embedding_lookup = nn.EmbeddingLookup(vocab_size=vocab_size, embedding_size=embedding_size,
max_norm=max_norm)
def construct(self, index):
res = self.embedding_lookup(index)
return res
class CountNonZero(nn.Cell):
"""CountNonZero net definition"""
@@ -1082,6 +1120,32 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
[2, 3], [2, 3]],
'desc_bprop': [[2, 3]]}),
('ClipByNorm_1', {
'block': ClipByNorm(),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
Tensor(np.array([0.01]).astype(np.float32))],
'skip': ['backward']}),
('ClipByNorm_2', {
'block': ClipByNorm(axis=0),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
Tensor(np.array([0.01]).astype(np.float32))],
'skip': ['backward']}),
('Embedding_1', {
'block': Embedding(vocab_size=10, embedding_size=3),
'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
'skip': ['backward']}),
('Embedding_2', {
'block': Embedding(vocab_size=10, embedding_size=3, padding_idx=2),
'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
'skip': ['backward']}),
('EmbeddingLookup_1', {
'block': EmbeddingLookup(vocab_size=10, embedding_size=3),
'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
'skip': ['backward']}),
('EmbeddingLookup_2', {
'block': EmbeddingLookup(vocab_size=10, embedding_size=3, max_norm=0.01),
'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
'skip': ['backward']}),
('Moments', {
'block': Moments(axis=(), keep_dims=False),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
