From 4154adf196b9f04a1e47bf6f0727fe86a6a4f4db Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Fri, 22 May 2020 18:35:15 +0800 Subject: [PATCH] add embedinglookup primitive --- mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/array_ops.py | 67 ++++++++++++++++ .../python/parallel/test_embeddinglookup.py | 79 +++++++++++++++++++ 3 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 tests/ut/python/parallel/test_embeddinglookup.py diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 762460d3e7..6f6c632f58 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -26,7 +26,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, - Shape, Size, Slice, Split, + Shape, Size, Slice, Split, EmbeddingLookup, Squeeze, StridedSlice, Tile, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, @@ -131,6 +131,7 @@ __all__ = [ 'ReduceMean', 'Range', 'LayerNorm', + 'EmbeddingLookup', 'Rank', 'Less', 'LessEqual', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 19a9ffd79d..63b4e0d001 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -572,6 +572,73 @@ class Range(PrimitiveWithInfer): return x_dtype +class EmbeddingLookup(PrimitiveWithInfer): + """ + Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar + functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`. + + Inputs: + - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + The Tensor slice, instead of the entire Tensor. + - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. + Specifies the indices of elements of the original Tensor. Must be in the range + `[0, input_param.shape()[axis])`. + - **axis** (int) - Specifies the dimension index to gather indices. + - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices + are equal to `input_indices` minus `offset`. + - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. + - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable + is used only if `reduce_scatter_flag` is True. + + + Outputs: + Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. + + Examples: + >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) + >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) + >>> axis = 0 + >>> offset = 4 + >>> reduce_scatter_flag = False + >>> split_num = 1 + >>> out = P.EmbeddingLookup()(input_params, input_indices, axis, offset, reduce_scatter_flag, split_num) + [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] + """ + @prim_attr_register + def __init__(self): + """init index_select""" + self.__setattr_flag__ = True + self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'], + outputs=['output']) + self.add_prim_attr('target', 'CPU') + + def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) + validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) + validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) + if split_num['value'] < 1: + raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) + axis_v = axis['value'] + params_shp = params['shape'] + rank = len(params_shp) + validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) + if axis_v < 0: + axis_v += rank + out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] + if reduce_scatter_flag: + # partition the tensor along the dimension 0. + if out_shape[0] % split_num['value'] != 0: + raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." % + (out_shape[0], split_num['value'])) + out_shape[0] = out_shape[0] // split_num['value'] + out = {'shape': out_shape, + 'dtype': params['dtype'], + 'value': None} + return out + + class Split(PrimitiveWithInfer): """ Splits input tensor into output_num of tensors along the given axis and output numbers. diff --git a/tests/ut/python/parallel/test_embeddinglookup.py b/tests/ut/python/parallel/test_embeddinglookup.py new file mode 100644 index 0000000000..9b7e36c6f1 --- /dev/null +++ b/tests/ut/python/parallel/test_embeddinglookup.py @@ -0,0 +1,79 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.common.api import _executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + +class Net(nn.Cell): + def __init__(self, shape, axis, offset, reduce_scatter_flag, split_num): + super().__init__() + self.index = Tensor(np.ones(shape), dtype=ms.int32) + self.axis = axis + self.offset = offset + self.reduce_scatter_flag = reduce_scatter_flag + self.split_num = split_num + self.elu = P.EmbeddingLookup() + self.mm = P.BatchMatMul() + + def construct(self, x, y): + out = self.elu(x, self.index, self.axis, self.offset, self.reduce_scatter_flag, self.split_num) + out = self.mm(out, y) + return out + + +def test_embeddinglookup_reducescatter_false(): + shape = [8, 8] + axis = 0 + offset = 8 + reduce_scatter_flag = False + split_num = 1 + net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num)) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32) + _executor.compile(net, x, y) + + +def test_embeddinglookup_reducescatter_true(): + shape = [8, 8] + axis = 0 + offset = 8 + reduce_scatter_flag = True + split_num = 8 + net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num)) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32) + _executor.compile(net, x, y)