add embedding layer.

pull/7651/head
jzg 4 years ago
parent 7c98803ad6
commit 7cbd55e17d

@@ -59,7 +59,7 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
   auto prim = std::make_shared<Primitive>(kClipByNormNoDivSumOpName);
   MS_EXCEPTION_IF_NULL(prim);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_select, constant_greater, constant_maximum};
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_greater, constant_select, constant_maximum};
   auto fusion_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(fusion_node);
   auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};

@@ -260,6 +260,7 @@ def _is_float_dtype(dtype):
         return True
     return False
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
@@ -271,6 +272,9 @@ class ClipByNorm(Cell):
         \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
     where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
+    Args:
+        axis (Union[None, int, tuple(int)]): Dimensions along which to compute the :math:`L_2`-norm.
+            Default: None, which computes the norm over all dimensions.
     Inputs:
         - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
@@ -287,8 +291,14 @@ class ClipByNorm(Cell):
     """
-    def __init__(self):
+    def __init__(self, axis=None):
         super(ClipByNorm, self).__init__()
+        if axis is None:
+            axis = ()
+        if isinstance(axis, tuple):
+            for idx, item in enumerate(axis):
+                Validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
+        self.axis = Validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.select_ = P.Select()
         self.greater_ = P.Greater()
@@ -305,7 +315,7 @@ class ClipByNorm(Cell):
     def construct(self, x, clip_norm):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
-        l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
+        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
         cond = self.greater_(l2sum, 0)
         ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
@@ -318,7 +328,9 @@ class ClipByNorm(Cell):
         intermediate = x * clip_norm
         max_norm = self.max_op(l2norm, clip_norm)
-        values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
+        if self.axis is None:
+            max_norm = self.expand_dims(max_norm, -1)
+        values_clip = self.cast(intermediate, mstype.float32) / max_norm
         values_clip = self.reshape(values_clip, self.shape(x))
         values_clip = identity(values_clip)
         return values_clip
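The formula this Cell implements can be checked against a plain-NumPy reference. The sketch below is for illustration only and is not part of the commit; the function name and sample shapes are made up. The input is scaled by clip_norm / max(L2(X), clip_norm), so slices whose norm is already below clip_norm pass through unchanged.

import numpy as np

def clip_by_norm_ref(x, clip_norm, axis=None):
    # L2-norm over the given axes (all axes when axis is None), kept for broadcasting.
    l2norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    # Where the norm exceeds clip_norm the slice is rescaled; otherwise x is returned unchanged.
    return x * clip_norm / np.maximum(l2norm, clip_norm)

x = np.random.rand(3, 16).astype(np.float32)
clipped = clip_by_norm_ref(x, clip_norm=0.01, axis=1)
# Every row norm is now capped at clip_norm.
assert np.allclose(np.linalg.norm(clipped, axis=1), 0.01, atol=1e-5)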

@@ -16,16 +16,21 @@
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
+from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
+from .basic import ClipByNorm
 from ..cell import Cell
 __all__ = ['Embedding', 'EmbeddingLookup']
 class Embedding(Cell):
     r"""
     A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -45,7 +50,8 @@ class Embedding(Cell):
             Refer to class `initializer` for the values of string when a string
             is specified. Default: 'normal'.
         dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
+        padding_idx (int, None): If given, the embedding vector at index `padding_idx` is initialized
+            to zero. Default: None, the feature is inactive.
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -63,16 +69,24 @@ class Embedding(Cell):
         >>> output.shape
         (8, 128, 768)
     """
-    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
+    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
+                 dtype=mstype.float32, padding_idx=None):
         super(Embedding, self).__init__()
-        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
+        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
+        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
         validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
-        self.vocab_size = vocab_size
-        self.embedding_size = embedding_size
+        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
         self.use_one_hot = use_one_hot
-        self.embedding_table = Parameter(initializer(embedding_table, [vocab_size, embedding_size]),
-                                         name='embedding_table')
         self.dtype = dtype
+        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
+        self.padding_idx = padding_idx
+        if padding_idx is not None:
+            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
+                                                         "padding_idx", self.cls_name)
+            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
+            self.init_tensor[self.padding_idx] = 0
+        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
         self.expand = P.ExpandDims()
         self.reshape_flat = P.Reshape()
         self.shp_flat = (-1,)
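With padding_idx set, the zeroed row can be observed directly. The following is a usage sketch only (it assumes an installed MindSpore and mirrors the 'Embedding_2' test case added further down); it is not part of the commit.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# padding_idx=2 zeroes row 2 of the freshly initialized embedding table.
net = nn.Embedding(vocab_size=10, embedding_size=3, padding_idx=2)
index = Tensor(np.array([0, 2, 2, 7]).astype(np.int32))
out = net(index)  # shape (4, 3); output rows 1 and 2 (lookups of index 2) are zero vectors before training
print(out.shape)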
@@ -99,16 +113,17 @@ class Embedding(Cell):
         return output
     def extend_repr(self):
-        s = 'vocab_size={}, embedding_size={},' \
-            'use_one_hot={}, ' \
-            'embedding_table={}, dtype={}'.format(
-                self.vocab_size,
-                self.embedding_size,
-                self.use_one_hot,
-                self.embedding_table,
-                self.dtype)
+        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
+            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
         return s
+@constexpr
+def _make_axis_range(start, end):
+    axis = tuple(range(start, end))
+    return axis
 class EmbeddingLookup(Cell):
     r"""
     Returns a slice of input tensor based on the specified indices.
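For orientation, the axes this helper produces are simply the trailing dimensions of the gather output, i.e. the embedding dimension(s). A plain-Python illustration (outside the graph, illustrative only):

# rank(indices) = 2, rank(out) = 3  ->  the norm is reduced over axis 2 only
start, end = 2, 3
axis = tuple(range(start, end))
assert axis == (2,)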
@@ -120,8 +135,7 @@ class EmbeddingLookup(Cell):
         When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
         specified 'axis = 0' to lookup table.
         In field slice mode, the manual_shapes must be given. It is a tuple ,where
-        the element is vocab[i], vocab[i] is the row numbers for i-th
-        part.
+        the element is vocab[i], vocab[i] is the row numbers for i-th part.
     Args:
         vocab_size (int): Size of the dictionary of embeddings.
@@ -132,6 +146,8 @@ class EmbeddingLookup(Cell):
         slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
             nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
         manual_shapes (tuple): The accompaniment array in field slice mode.
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
+            or None. Default: None.
     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -152,7 +168,7 @@ class EmbeddingLookup(Cell):
     TABLE_COLUMN_SLICE = "table_column_slice"
     def __init__(self, vocab_size, embedding_size, param_init='normal',
-                 target='CPU', slice_mode='batch_slice', manual_shapes=None):
+                 target='CPU', slice_mode='batch_slice', manual_shapes=None, max_norm=None):
         super(EmbeddingLookup, self).__init__()
         self.target = target
         if target not in ('CPU', 'DEVICE'):
@@ -160,7 +176,9 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         self.gatherv2 = P.GatherV2()
         self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
-        self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
+        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
+        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                          name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
@@ -188,10 +206,18 @@ class EmbeddingLookup(Cell):
             if is_auto_parallel:
                 raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
                                  + str(slice_mode))
+        self.max_norm = max_norm
+        if self.max_norm is not None:
+            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
+            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
     def construct(self, indices):
         if self.target == "CPU":
             out = self.embeddinglookup(self.embedding_table, indices, 0)
         else:
             out = self.gatherv2(self.embedding_table, indices, 0)
+        if self.max_norm is not None:
+            axis = _make_axis_range(F.rank(indices), F.rank(out))
+            clip_by_norm = ClipByNorm(axis)
+            out = clip_by_norm(out, self.max_norm)
         return out
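Putting the pieces together: max_norm caps the L2-norm of every looked-up vector, with the norm taken over the axes produced by _make_axis_range (the embedding dimension). A usage sketch mirroring the 'EmbeddingLookup_2' test case below; not part of the commit and illustrative only.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.EmbeddingLookup(vocab_size=10, embedding_size=3, max_norm=0.01)
indices = Tensor(np.array([0, 2, 2, 7]).astype(np.int32))
out = net(indices)  # shape (4, 3)
# Each of the 4 returned vectors has L2-norm <= 0.01 after clipping.
print(np.linalg.norm(out.asnumpy(), axis=-1))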

@@ -228,6 +228,44 @@ class Moments(nn.Cell):
         return mean, variance
+class ClipByNorm(nn.Cell):
+    """ClipByNorm net definition"""
+    def __init__(self, axis=None):
+        super(ClipByNorm, self).__init__()
+        self.clip_by_norm = nn.ClipByNorm(axis=axis)
+    def construct(self, input_x, max_norm):
+        norm = self.clip_by_norm(input_x, max_norm)
+        return norm
+class Embedding(nn.Cell):
+    """Embedding net definition"""
+    def __init__(self, vocab_size, embedding_size, padding_idx=None):
+        super(Embedding, self).__init__()
+        self.embedding = nn.Embedding(vocab_size=vocab_size, embedding_size=embedding_size,
+                                      padding_idx=padding_idx)
+    def construct(self, index):
+        res = self.embedding(index)
+        return res
+class EmbeddingLookup(nn.Cell):
+    """EmbeddingLookup net definition"""
+    def __init__(self, vocab_size, embedding_size, max_norm=None):
+        super(EmbeddingLookup, self).__init__()
+        self.embedding_lookup = nn.EmbeddingLookup(vocab_size=vocab_size, embedding_size=embedding_size,
+                                                   max_norm=max_norm)
+    def construct(self, index):
+        res = self.embedding_lookup(index)
+        return res
 class CountNonZero(nn.Cell):
     """CountNonZero net definition"""
@@ -1082,6 +1120,32 @@ test_case_math_ops = [
         'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
                         [2, 3], [2, 3]],
         'desc_bprop': [[2, 3]]}),
+    ('ClipByNorm_1', {
+        'block': ClipByNorm(),
+        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
+                        Tensor(np.array([0.01]).astype(np.float32))],
+        'skip': ['backward']}),
+    ('ClipByNorm_2', {
+        'block': ClipByNorm(axis=0),
+        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
+                        Tensor(np.array([0.01]).astype(np.float32))],
+        'skip': ['backward']}),
+    ('Embedding_1', {
+        'block': Embedding(vocab_size=10, embedding_size=3),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
+    ('Embedding_2', {
+        'block': Embedding(vocab_size=10, embedding_size=3, padding_idx=2),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
+    ('EmbeddingLookup_1', {
+        'block': EmbeddingLookup(vocab_size=10, embedding_size=3),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
+    ('EmbeddingLookup_2', {
+        'block': EmbeddingLookup(vocab_size=10, embedding_size=3, max_norm=0.01),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
     ('Moments', {
         'block': Moments(axis=(), keep_dims=False),
         'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
