|
|
|
@ -576,19 +576,21 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
Returns a slice of the input tensor based on the specified indices and axis. This Primitive has similar
|
|
|
|
|
functionality as GatherV2, but has three more inputs: `offset`, `reduce_scatter_flag` and `split_num`.
|
|
|
|
|
This primitive runs on the host instead of devices.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
|
|
|
The Tensor slice, instead of the entire Tensor.
|
|
|
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
|
|
|
Specifies the indices of elements of the original Tensor. Must be in the range
|
|
|
|
|
`[0, input_params.shape()[axis])`.
|
|
|
|
|
Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
|
|
|
|
|
and the exceeding part will be filled with 0 in the output.
|
|
|
|
|
- **axis** (int) - Specifies the dimension index to gather indices.
|
|
|
|
|
- **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
|
|
|
|
|
are equal to `input_indices` minus `offset`.
|
|
|
|
|
- **reduce_scatter_flag** (bool) - Specifies whether to perform reduce_scatter on host or not.
|
|
|
|
|
Only constant value is allowed.
|
|
|
|
|
- **split_num** (int) - Specifies the number of partitions that the reduce_scatter produces. This variable
|
|
|
|
|
is used only if `reduce_scatter_flag` is True.
|
|
|
|
|
is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
@ -627,12 +629,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
|
if axis_v < 0:
|
|
|
|
|
axis_v += rank
|
|
|
|
|
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
|
|
|
|
if reduce_scatter_flag:
|
|
|
|
|
# partition the tensor along the dimension 0.
|
|
|
|
|
if out_shape[0] % split_num['value'] != 0:
|
|
|
|
|
raise ValueError("The dimension 0 of the shape: %d, is not divisible by split_num: %d." %
|
|
|
|
|
(out_shape[0], split_num['value']))
|
|
|
|
|
out_shape[0] = out_shape[0] // split_num['value']
|
|
|
|
|
if reduce_scatter_flag is None:
|
|
|
|
|
raise ValueError("The value of 'reduce_scatter_flag' is None.")
|
|
|
|
|
reduce_scatter_flag_value = reduce_scatter_flag['value']
|
|
|
|
|
if split_num is None:
|
|
|
|
|
raise ValueError("The value of 'split_num_value' is None.")
|
|
|
|
|
split_num_value = split_num['value']
|
|
|
|
|
if reduce_scatter_flag_value is True:
|
|
|
|
|
# Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
|
|
|
|
|
# (split_num * 8)
|
|
|
|
|
if out_shape[0] % (split_num_value * 8) != 0:
|
|
|
|
|
raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
|
|
|
|
|
(out_shape[0], (split_num_value * 8)))
|
|
|
|
|
# After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
|
|
|
|
|
out_shape[0] = out_shape[0] // 8
|
|
|
|
|
out = {'shape': out_shape,
|
|
|
|
|
'dtype': params['dtype'],
|
|
|
|
|
'value': None}
|
|
|
|
|