diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc
index 6b123c6359..61726eeb99 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc
@@ -59,7 +59,7 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
   auto prim = std::make_shared<Primitive>(kClipByNormNoDivSumOpName);
   MS_EXCEPTION_IF_NULL(prim);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_select, constant_greater, constant_maximum};
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input, constant_greater, constant_select, constant_maximum};
   auto fusion_node = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(fusion_node);
   auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 5e9283fed7..5c7ce17d8c 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -260,6 +260,7 @@ def _is_float_dtype(dtype):
         return True
     return False

+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
@@ -271,6 +272,9 @@ class ClipByNorm(Cell):
         \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},

     where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

+    Args:
+        axis (Union[None, int, tuple(int)]): Compute the L2-norm along the specified dimension(s).
+            Default: None, which computes the norm over all dimensions.
     Inputs:
         - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
@@ -287,8 +291,14 @@
     """

-    def __init__(self):
+    def __init__(self, axis=None):
         super(ClipByNorm, self).__init__()
+        if axis is None:
+            axis = ()
+        if isinstance(axis, tuple):
+            for idx, item in enumerate(axis):
+                Validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
+        self.axis = Validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.select_ = P.Select()
         self.greater_ = P.Greater()
@@ -305,7 +315,7 @@ class ClipByNorm(Cell):
     def construct(self, x, clip_norm):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
-        l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
+        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
         cond = self.greater_(l2sum, 0)
         ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
@@ -318,7 +328,9 @@ class ClipByNorm(Cell):
         intermediate = x * clip_norm

         max_norm = self.max_op(l2norm, clip_norm)
-        values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1)
+        if self.axis is None:
+            max_norm = self.expand_dims(max_norm, -1)
+        values_clip = self.cast(intermediate, mstype.float32) / max_norm
         values_clip = self.reshape(values_clip, self.shape(x))
         values_clip = identity(values_clip)
         return values_clip
diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index 479ed4dea8..943fe274a2 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -16,16 +16,21 @@
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
+from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.primitive import constexpr
+from .basic import ClipByNorm
 from ..cell import Cell

 __all__ = ['Embedding', 'EmbeddingLookup']

+
 class Embedding(Cell):
     r"""
     A simple lookup table that stores embeddings of a fixed dictionary and size.
@@ -45,7 +50,8 @@ class Embedding(Cell):
             Refer to class `initializer` for the values of string when a string is specified. Default: 'normal'.
         dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
-
+        padding_idx (int, None): If given, the output embedding vector at index `padding_idx` is initialized
+            to zero. Default: None, in which case the feature is disabled.
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -63,16 +69,24 @@ class Embedding(Cell):
         >>> output.shape
         (8, 128, 768)
     """
-    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
+
+    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
+                 dtype=mstype.float32, padding_idx=None):
         super(Embedding, self).__init__()
-        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
+        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
+        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
         validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
-        self.vocab_size = vocab_size
-        self.embedding_size = embedding_size
+        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
         self.use_one_hot = use_one_hot
-        self.embedding_table = Parameter(initializer(embedding_table, [vocab_size, embedding_size]),
-                                         name='embedding_table')
         self.dtype = dtype
+        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
+        self.padding_idx = padding_idx
+        if padding_idx is not None:
+            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
+                                                         "padding_idx", self.cls_name)
+            self.init_tensor = self.init_tensor.to_tensor().asnumpy()
+            self.init_tensor[self.padding_idx] = 0
+        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
         self.expand = P.ExpandDims()
         self.reshape_flat = P.Reshape()
         self.shp_flat = (-1,)
@@ -99,16 +113,17 @@ class Embedding(Cell):
         return output

     def extend_repr(self):
-        s = 'vocab_size={}, embedding_size={},' \
-            'use_one_hot={}, ' \
-            'embedding_table={}, dtype={}'.format(
-                self.vocab_size,
-                self.embedding_size,
-                self.use_one_hot,
-                self.embedding_table,
-                self.dtype)
+        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
+            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
         return s

+
+@constexpr
+def _make_axis_range(start, end):
+    axis = tuple(range(start, end))
+    return axis
+
+
 class EmbeddingLookup(Cell):
     r"""
     Returns a slice of input tensor based on the specified indices.
@@ -120,8 +135,7 @@ class EmbeddingLookup(Cell):
         When 'target' is set to 'DEVICE', this module will use P.GatherV2() which
         specified 'axis = 0' to lookup table.
         In field slice mode, the manual_shapes must be given. It is a tuple ,where
-        the element is vocab[i], vocab[i] is the row numbers for i-th
-        part.
+        the element is vocab[i], vocab[i] is the row numbers for i-th part.

     Args:
         vocab_size (int): Size of the dictionary of embeddings.
@@ -132,6 +146,8 @@ class EmbeddingLookup(Cell):
         slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
             nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
         manual_shapes (tuple): The accompaniment array in field slice mode.
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
+            or None. Default: None.

     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -152,7 +168,7 @@ class EmbeddingLookup(Cell):
     TABLE_COLUMN_SLICE = "table_column_slice"

     def __init__(self, vocab_size, embedding_size, param_init='normal',
-                 target='CPU', slice_mode='batch_slice', manual_shapes=None):
+                 target='CPU', slice_mode='batch_slice', manual_shapes=None, max_norm=None):
         super(EmbeddingLookup, self).__init__()
         self.target = target
         if target not in ('CPU', 'DEVICE'):
@@ -160,7 +176,9 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         self.gatherv2 = P.GatherV2()
         self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
-        self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
+        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
+        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                          name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
@@ -188,10 +206,18 @@ class EmbeddingLookup(Cell):
             if is_auto_parallel:
                 raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
                                  + str(slice_mode))
+        self.max_norm = max_norm
+        if self.max_norm is not None:
+            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
+            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)

     def construct(self, indices):
         if self.target == "CPU":
             out = self.embeddinglookup(self.embedding_table, indices, 0)
         else:
             out = self.gatherv2(self.embedding_table, indices, 0)
+        if self.max_norm is not None:
+            axis = _make_axis_range(F.rank(indices), F.rank(out))
+            clip_by_norm = ClipByNorm(axis)
+            out = clip_by_norm(out, self.max_norm)
         return out
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index ad717bc227..e4193c1643 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -228,6 +228,44 @@ class Moments(nn.Cell):
         return mean, variance

+
+class ClipByNorm(nn.Cell):
+    """ClipByNorm net definition"""
+
+    def __init__(self, axis=None):
+        super(ClipByNorm, self).__init__()
+        self.clip_by_norm = nn.ClipByNorm(axis=axis)
+
+    def construct(self, input_x, max_norm):
+        norm = self.clip_by_norm(input_x, max_norm)
+        return norm
+
+
+class Embedding(nn.Cell):
+    """Embedding net definition"""
+
+    def __init__(self, vocab_size, embedding_size, padding_idx=None):
+        super(Embedding, self).__init__()
+        self.embedding = nn.Embedding(vocab_size=vocab_size, embedding_size=embedding_size,
+                                      padding_idx=padding_idx)
+
+    def construct(self, index):
+        res = self.embedding(index)
+        return res
+
+
+class EmbeddingLookup(nn.Cell):
+    """EmbeddingLookup net definition"""
+
+    def __init__(self, vocab_size, embedding_size, max_norm=None):
+        super(EmbeddingLookup, self).__init__()
+        self.embedding_lookup = nn.EmbeddingLookup(vocab_size=vocab_size, embedding_size=embedding_size,
+                                                   max_norm=max_norm)
+
+    def construct(self, index):
+        res = self.embedding_lookup(index)
+        return res
+
+
 class CountNonZero(nn.Cell):
     """CountNonZero net definition"""

@@ -1082,6 +1120,32 @@
         'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
                         [2, 3], [2, 3]],
         'desc_bprop': [[2, 3]]}),
+    ('ClipByNorm_1', {
+        'block': ClipByNorm(),
+        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
+                        Tensor(np.array([0.01]).astype(np.float32))],
+        'skip': ['backward']}),
+    ('ClipByNorm_2', {
+        'block': ClipByNorm(axis=0),
+        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32)),
+                        Tensor(np.array([0.01]).astype(np.float32))],
+        'skip': ['backward']}),
+    ('Embedding_1', {
+        'block': Embedding(vocab_size=10, embedding_size=3),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
+    ('Embedding_2', {
+        'block': Embedding(vocab_size=10, embedding_size=3, padding_idx=2),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
+    ('EmbeddingLookup_1', {
+        'block': EmbeddingLookup(vocab_size=10, embedding_size=3),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
+    ('EmbeddingLookup_2', {
+        'block': EmbeddingLookup(vocab_size=10, embedding_size=3, max_norm=0.01),
+        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],
+        'skip': ['backward']}),
     ('Moments', {
        'block': Moments(axis=(), keep_dims=False),
        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
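
Note: a minimal usage sketch of the three parameters this patch introduces (axis on nn.ClipByNorm, padding_idx on nn.Embedding, max_norm on nn.EmbeddingLookup). It mirrors the new unit-test cases above; the shapes and values are illustrative only and assume a MindSpore build that already contains this change.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# ClipByNorm with an explicit axis: the L2-norm is computed along axis 0 only,
# instead of over all dimensions (the default when axis=None).
net = nn.ClipByNorm(axis=0)
x = Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))
clip_norm = Tensor(np.array([0.01]).astype(np.float32))
clipped = net(x, clip_norm)

# Embedding with padding_idx=2: the embedding vector stored at index 2 is zero-initialized.
embed = nn.Embedding(vocab_size=10, embedding_size=3, padding_idx=2)
indices = Tensor(np.array([0, 2, 2, 7]).astype(np.int32))
vectors = embed(indices)

# EmbeddingLookup with max_norm: each looked-up vector is clipped to an L2-norm of 0.01.
lookup = nn.EmbeddingLookup(vocab_size=10, embedding_size=3, max_norm=0.01)
clipped_vectors = lookup(indices)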