Move EmbeddingLookup to the internal ops module (_inner_ops)

pull/2163/head
Xiaoda Zhang 5 years ago
parent 7038df8b99
commit 55e7d9d2b8

@ -191,7 +191,7 @@ def get_bprop_tile(self):
return bprop
@bprop_getters.register(P.EmbeddingLookup)
@bprop_getters.register(inner.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
"""Generate bprop for EmbeddingLookup"""
host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU')

@ -26,7 +26,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, EmbeddingLookup,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
@ -138,7 +138,6 @@ __all__ = [
'ReduceSum',
'ReduceMean',
'LayerNorm',
'EmbeddingLookup',
'Rank',
'Less',
'LessEqual',

@ -258,3 +258,73 @@ class AscendDequant(PrimitiveWithInfer):
validator.check_type_name("x", x_type, [mstype.int32], self.name)
validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
return mstype.float16
class EmbeddingLookup(PrimitiveWithInfer):
    """
    Returns a slice of the input tensor based on the specified indices.

    This primitive has functionality similar to GatherV2 operating on `axis = 0`, but has three more inputs:
    `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          The Tensor slice, instead of the entire Tensor.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
          and the exceeding part will be filled with 0 in the output.
        - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
          are equal to `input_indices` minus `offset`.
        - **reduce_scatter_flag** (bool) - Specifies whether to perform reduce_scatter on host or not.
          Only constant value is allowed.
        - **split_num** (int) - Specifies the number of partitions the reduce_scatter produces. This variable
          is used only if `reduce_scatter_flag` is True. Only constant value is allowed.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. The inferred shape is the shape of
        `input_indices` followed by the shape of `input_params` without its first dimension; when
        `reduce_scatter_flag` is True, dimension 0 of that shape is divided by 8.

    Examples:
        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
        >>> offset = 4
        >>> reduce_scatter_flag = False
        >>> split_num = 1
        >>> out = inner.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
        [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize EmbeddingLookup: declare the IO names and pin the primitive to the CPU target."""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
                                outputs=['output'])
        self.add_prim_attr('primitive_target', 'CPU')

    def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
        """
        Validate the input types and infer the output shape and dtype.

        Each argument is subscripted as a mapping: 'dtype' and 'shape' are read from `params` and
        `indices`, 'dtype' from `offset`, 'value' from `reduce_scatter_flag`, and 'dtype' and 'value'
        from `split_num`.

        Returns:
            dict with keys 'shape' (inferred output shape), 'dtype' (the dtype of `params`) and
            'value' (always None).

        Raises:
            ValueError: If `split_num` or `reduce_scatter_flag` is None, if the value of `split_num` is
                less than 1, or if `reduce_scatter_flag` is True and dimension 0 of the output shape is
                not divisible by `split_num * 8`. The `validator` checks may raise errors of their own.
        """
        # NOTE(review): the defaults (`False`, `2`) are plain values, but both arguments are subscripted
        # below, so relying on the defaults would raise TypeError. They are kept unchanged for interface
        # compatibility.
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
        validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
        # Reject a missing split_num before it is subscripted; the guard used to sit after the first
        # subscript and could never fire.
        if split_num is None:
            raise ValueError("The value of 'split_num_value' is None.")
        validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
        split_num_value = split_num['value']
        if split_num_value < 1:
            # Format the integer value, not the whole mapping: '%d' applied to a dict raises TypeError
            # and would hide this ValueError.
            raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num_value)
        params_shp = params['shape']
        # Gather along axis 0: the indices shape replaces the first dimension of params.
        out_shape = indices['shape'] + params_shp[1:]
        if reduce_scatter_flag is None:
            raise ValueError("The value of 'reduce_scatter_flag' is None.")
        reduce_scatter_flag_value = reduce_scatter_flag['value']
        if reduce_scatter_flag_value is True:
            # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
            # (split_num * 8)
            if out_shape[0] % (split_num_value * 8) != 0:
                raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
                                 (out_shape[0], (split_num_value * 8)))
            # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
            out_shape[0] = out_shape[0] // 8
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out

@ -558,76 +558,6 @@ class SparseGatherV2(GatherV2):
"""
class EmbeddingLookup(PrimitiveWithInfer):
    """
    Returns a slice of the input tensor based on the specified indices.

    This primitive has functionality similar to GatherV2 operating on `axis = 0`, but has three more inputs:
    `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          The Tensor slice, instead of the entire Tensor.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
          and the exceeding part will be filled with 0 in the output.
        - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
          are equal to `input_indices` minus `offset`.
        - **reduce_scatter_flag** (bool) - Specifies whether to perform reduce_scatter on host or not.
          Only constant value is allowed.
        - **split_num** (int) - Specifies the number of partitions the reduce_scatter produces. This variable
          is used only if `reduce_scatter_flag` is True. Only constant value is allowed.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. The inferred shape is the shape of
        `input_indices` followed by the shape of `input_params` without its first dimension; when
        `reduce_scatter_flag` is True, dimension 0 of that shape is divided by 8.

    Examples:
        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
        >>> offset = 4
        >>> reduce_scatter_flag = False
        >>> split_num = 1
        >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
        [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize EmbeddingLookup: declare the IO names and pin the primitive to the CPU target."""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
                                outputs=['output'])
        self.add_prim_attr('primitive_target', 'CPU')

    def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
        """
        Validate the input types and infer the output shape and dtype.

        Each argument is subscripted as a mapping: 'dtype' and 'shape' are read from `params` and
        `indices`, 'dtype' from `offset`, 'value' from `reduce_scatter_flag`, and 'dtype' and 'value'
        from `split_num`.

        Returns:
            dict with keys 'shape' (inferred output shape), 'dtype' (the dtype of `params`) and
            'value' (always None).

        Raises:
            ValueError: If `split_num` or `reduce_scatter_flag` is None, if the value of `split_num` is
                less than 1, or if `reduce_scatter_flag` is True and dimension 0 of the output shape is
                not divisible by `split_num * 8`. The `validator` checks may raise errors of their own.
        """
        # NOTE(review): the defaults (`False`, `2`) are plain values, but both arguments are subscripted
        # below, so relying on the defaults would raise TypeError. They are kept unchanged for interface
        # compatibility.
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
        validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
        # Reject a missing split_num before it is subscripted; the guard used to sit after the first
        # subscript and could never fire.
        if split_num is None:
            raise ValueError("The value of 'split_num_value' is None.")
        validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
        split_num_value = split_num['value']
        if split_num_value < 1:
            # Format the integer value, not the whole mapping: '%d' applied to a dict raises TypeError
            # and would hide this ValueError.
            raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num_value)
        params_shp = params['shape']
        # Gather along axis 0: the indices shape replaces the first dimension of params.
        out_shape = indices['shape'] + params_shp[1:]
        if reduce_scatter_flag is None:
            raise ValueError("The value of 'reduce_scatter_flag' is None.")
        reduce_scatter_flag_value = reduce_scatter_flag['value']
        if reduce_scatter_flag_value is True:
            # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
            # (split_num * 8)
            if out_shape[0] % (split_num_value * 8) != 0:
                raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
                                 (out_shape[0], (split_num_value * 8)))
            # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
            out_shape[0] = out_shape[0] // 8
        out = {'shape': out_shape,
               'dtype': params['dtype'],
               'value': None}
        return out
class Split(PrimitiveWithInfer):
"""
Splits input tensor into output_num of tensors along the given axis and output numbers.

@ -19,6 +19,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from tests.ut.python.ops.test_math_ops import VirtualLoss
@ -39,7 +40,7 @@ class Net(nn.Cell):
self.offset = offset
self.reduce_scatter_flag = reduce_scatter_flag
self.split_num = split_num
self.elu = P.EmbeddingLookup()
self.elu = inner.EmbeddingLookup()
self.mm = P.BatchMatMul()
def construct(self, x, y):

Loading…
Cancel
Save