|
|
@ -616,9 +616,10 @@ class Range(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar
|
|
|
|
Returns a slice of input tensor based on the specified indices.
|
|
|
|
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`.
|
|
|
|
|
|
|
|
This primitive runs on the host instead of devices.
|
|
|
|
This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs:
|
|
|
|
|
|
|
|
`offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
|
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
|
@ -626,7 +627,6 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
|
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
|
|
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
|
|
|
and the exceeding part will be filled with 0 in the output.
|
|
|
|
and the exceeding part will be filled with 0 in the output.
|
|
|
|
- **axis** (int) - Specifies the dimension index to gather indices.
|
|
|
|
|
|
|
|
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
|
|
|
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
|
|
|
are equal to `input_indices` minus `offset`.
|
|
|
|
are equal to `input_indices` minus `offset`.
|
|
|
|
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
|
|
|
|
- **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
|
|
|
@ -641,36 +641,29 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
|
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
|
|
|
>>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
|
|
|
|
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
|
|
|
>>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
|
|
|
|
>>> axis = 0
|
|
|
|
|
|
|
|
>>> offset = 4
|
|
|
|
>>> offset = 4
|
|
|
|
>>> reduce_scatter_flag = False
|
|
|
|
>>> reduce_scatter_flag = False
|
|
|
|
>>> split_num = 1
|
|
|
|
>>> split_num = 1
|
|
|
|
>>> out = P.EmbeddingLookup()(input_params, input_indices, axis, offset, reduce_scatter_flag, split_num)
|
|
|
|
>>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
|
|
|
|
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
|
|
|
[[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
@prim_attr_register
|
|
|
|
@prim_attr_register
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
"""init index_select"""
|
|
|
|
"""init index_select"""
|
|
|
|
self.__setattr_flag__ = True
|
|
|
|
self.__setattr_flag__ = True
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
|
|
|
|
outputs=['output'])
|
|
|
|
outputs=['output'])
|
|
|
|
self.add_prim_attr('primitive_target', 'CPU')
|
|
|
|
self.add_prim_attr('primitive_target', 'CPU')
|
|
|
|
|
|
|
|
|
|
|
|
def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
|
|
|
|
def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
|
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
|
|
|
validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
|
|
|
|
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
|
|
|
|
|
|
|
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
|
|
|
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
|
|
|
validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
|
|
|
|
validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
|
|
|
|
if split_num['value'] < 1:
|
|
|
|
if split_num['value'] < 1:
|
|
|
|
raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num)
|
|
|
|
raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num)
|
|
|
|
axis_v = axis['value']
|
|
|
|
|
|
|
|
params_shp = params['shape']
|
|
|
|
params_shp = params['shape']
|
|
|
|
rank = len(params_shp)
|
|
|
|
out_shape = indices['shape'] + params_shp[1:]
|
|
|
|
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
|
|
|
|
|
|
|
|
if axis_v < 0:
|
|
|
|
|
|
|
|
axis_v += rank
|
|
|
|
|
|
|
|
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
|
|
|
|
|
|
|
if reduce_scatter_flag is None:
|
|
|
|
if reduce_scatter_flag is None:
|
|
|
|
raise ValueError("The value of 'reduce_scatter_flag' is None.")
|
|
|
|
raise ValueError("The value of 'reduce_scatter_flag' is None.")
|
|
|
|
reduce_scatter_flag_value = reduce_scatter_flag['value']
|
|
|
|
reduce_scatter_flag_value = reduce_scatter_flag['value']
|
|
|
|