!1372 [Auto parallel] Add a new primitive EmbeddingLookup

Merge pull request !1372 from Xiaoda/add-embedinglookup-primitive
pull/1372/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 57874cd61f

@@ -26,7 +26,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
                        SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
                        ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
-                       Shape, Size, Slice, Split,
+                       Shape, Size, Slice, Split, EmbeddingLookup,
                        Squeeze, StridedSlice, Tile,
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                        UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
@@ -131,6 +131,7 @@ __all__ = [
    'ReduceMean',
    'Range',
    'LayerNorm',
+   'EmbeddingLookup',
    'Rank',
    'Less',
    'LessEqual',

@@ -572,6 +572,73 @@ class Range(PrimitiveWithInfer):
        return x_dtype


class EmbeddingLookup(PrimitiveWithInfer):
    """
    Returns a slice of the input tensor based on the specified indices and axis. This primitive has
    functionality similar to GatherV2, but takes three additional inputs: `offset`,
    `reduce_scatter_flag` and `split_num`.

    Inputs:
        - **input_params** (Tensor) - The shape of the tensor is :math:`(x_1, x_2, ..., x_R)`.
          This is a slice of the parameter table rather than the entire tensor.
        - **input_indices** (Tensor) - The shape of the tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original tensor.
        - **axis** (int) - Specifies the dimension index to gather indices.
        - **offset** (int) - Specifies the offset of this `input_params` slice, so the real indices
          are equal to `input_indices` minus `offset`. Adjusted indices that fall outside
          `[0, input_params.shape()[axis])` produce zero-filled rows in the output, as the example
          below shows.
        - **reduce_scatter_flag** (bool) - Specifies whether to perform reduce_scatter on the host.
        - **split_num** (int) - Specifies the number of partitions that the reduce_scatter produces.
          This value is used only when `reduce_scatter_flag` is True.

    Outputs:
        Tensor, the shape of the tensor is :math:`(z_1, z_2, ..., z_N)`.

    Examples:
        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
        >>> axis = 0
        >>> offset = 4
        >>> reduce_scatter_flag = False
        >>> split_num = 1
        >>> out = P.EmbeddingLookup()(input_params, input_indices, axis, offset, reduce_scatter_flag, split_num)
        [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]
    """
    @prim_attr_register
    def __init__(self):
        """init EmbeddingLookup"""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
                                outputs=['output'])
        self.add_prim_attr('target', 'CPU')

    def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
        validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
        validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
        if split_num['value'] < 1:
            raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num['value'])
        axis_v = axis['value']
        params_shp = params['shape']
        rank = len(params_shp)
        validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
        if axis_v < 0:
            axis_v += rank
        out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
        if reduce_scatter_flag:
            # Partition the output along dimension 0.
            if out_shape[0] % split_num['value'] != 0:
                raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." %
                                 (out_shape[0], split_num['value']))
            out_shape[0] = out_shape[0] // split_num['value']
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out
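
For intuition, the offset and zero-fill semantics documented above can be reproduced in a few lines of NumPy. This is only a sketch of the behavior for the `axis = 0` case from the docstring example; the `embedding_lookup_sketch` helper is hypothetical and not part of MindSpore:

import numpy as np

def embedding_lookup_sketch(params, indices, offset):
    # Gather rows by (indices - offset); rows whose adjusted index falls
    # outside the local params slice are filled with zeros.
    real = indices - offset
    out = np.zeros(indices.shape + params.shape[1:], dtype=params.dtype)
    valid = (real >= 0) & (real < params.shape[0])
    out[valid] = params[real[valid]]
    return out

params = np.array([[8, 9], [10, 11], [12, 13], [14, 15]], dtype=np.float32)
indices = np.array([[5, 2], [8, 5]], dtype=np.int32)
print(embedding_lookup_sketch(params, indices, offset=4))
# -> [[[10, 11], [0, 0]], [[0, 0], [10, 11]]], matching the docstring example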

class Split(PrimitiveWithInfer):
    """
    Splits the input tensor into `output_num` tensors along the given axis.

@@ -0,0 +1,79 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class Net(nn.Cell):
    def __init__(self, shape, axis, offset, reduce_scatter_flag, split_num):
        super().__init__()
        self.index = Tensor(np.ones(shape), dtype=ms.int32)
        self.axis = axis
        self.offset = offset
        self.reduce_scatter_flag = reduce_scatter_flag
        self.split_num = split_num
        self.elu = P.EmbeddingLookup()  # the primitive under test
        self.mm = P.BatchMatMul()

    def construct(self, x, y):
        out = self.elu(x, self.index, self.axis, self.offset, self.reduce_scatter_flag, self.split_num)
        out = self.mm(out, y)
        return out


def test_embeddinglookup_reducescatter_false():
    shape = [8, 8]
    axis = 0
    offset = 8
    reduce_scatter_flag = False
    split_num = 1
    net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
    _executor.compile(net, x, y)


def test_embeddinglookup_reducescatter_true():
    shape = [8, 8]
    axis = 0
    offset = 8
    reduce_scatter_flag = True
    split_num = 8
    net = NetWithLoss(Net(shape, axis, offset, reduce_scatter_flag, split_num))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([1, 32, 8]), dtype=ms.float32)
    _executor.compile(net, x, y)
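
As a sanity check, the output shapes these two tests exercise follow directly from the `__infer__` rule in the primitive above: splice the indices shape into the params shape at `axis`, then divide dimension 0 by `split_num` when `reduce_scatter_flag` is set. A minimal sketch of that rule (the `infer_shape` helper is hypothetical, written only for this walkthrough):

def infer_shape(params_shape, indices_shape, axis, reduce_scatter_flag, split_num):
    # Mirrors EmbeddingLookup.__infer__'s shape computation.
    out = params_shape[:axis] + indices_shape + params_shape[axis + 1:]
    if reduce_scatter_flag:
        # reduce_scatter partitions the output along dimension 0.
        assert out[0] % split_num == 0
        out[0] //= split_num
    return out

# reducescatter_false: [8, 8] indices into [64, 32] params -> [8, 8, 32],
# so y must be [8, 32, 8] for the BatchMatMul.
print(infer_shape([64, 32], [8, 8], 0, False, 1))
# reducescatter_true with split_num = 8: dimension 0 shrinks to 8 // 8 = 1,
# so y must be [1, 32, 8].
print(infer_shape([64, 32], [8, 8], 0, True, 8))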