|
|
@ -5284,18 +5284,18 @@ class CTCLoss(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **inputs** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
|
|
|
|
- **inputs** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
|
|
|
|
:math:`(max_time, batch_size, num_classes)`. `num_classes` must be `num_labels + 1` classes, `num_labels`
|
|
|
|
(`max_time`, `batch_size`, `num_classes`). `num_classes` must be `num_labels + 1` classes, `num_labels`
|
|
|
|
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
|
|
|
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
|
|
|
Data type must be float16, float32 or float64.
|
|
|
|
Data type must be float16, float32 or float64.
|
|
|
|
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
|
|
|
|
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
|
|
|
|
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
|
|
|
|
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
|
|
|
|
- **labels_values** (Tensor) - A `1-D` input tensor. The values are associated with the given batch and time.
|
|
|
|
- **labels_values** (Tensor) - A `1-D` input tensor. The values are associated with the given batch and time.
|
|
|
|
The type must be int32. `labels_values[i]` must in the range of `[0, num_classes)`.
|
|
|
|
The type must be int32. `labels_values[i]` must in the range of `[0, num_classes)`.
|
|
|
|
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
|
|
|
|
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`).
|
|
|
|
The type must be int32. Each value in the tensor must not be greater than `max_time`.
|
|
|
|
The type must be int32. Each value in the tensor must not be greater than `max_time`.
|
|
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
Outputs:
|
|
|
|
- **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch_size)`. The tensor has
|
|
|
|
- **loss** (Tensor) - A tensor containing log-probabilities, the shape is (`batch_size`). The tensor has
|
|
|
|
the same type with `inputs`.
|
|
|
|
the same type with `inputs`.
|
|
|
|
- **gradient** (Tensor) - The gradient of `loss`, has the same type and shape with `inputs`.
|
|
|
|
- **gradient** (Tensor) - The gradient of `loss`, has the same type and shape with `inputs`.
|
|
|
|
|
|
|
|
|
|
|
@ -5353,21 +5353,20 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
Inputs:
|
|
|
|
- **inputs** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
|
|
|
|
- **inputs** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
|
|
|
|
:math:`(\text{max_time}, \text{batch_size}, \text{num_classes})`. `num_classes` must be
|
|
|
|
(`max_time`, `batch_size`, `num_classes`). `num_classes` must be `num_labels + 1` classes,
|
|
|
|
`num_labels + 1` classes, `num_labels` indicates the number of actual labels. Blank labels are reserved.
|
|
|
|
`num_labels` indicates the number of actual labels. Blank labels are reserved.
|
|
|
|
Default blank label is `num_classes - 1`. Data type must be float32 or float64.
|
|
|
|
Default blank label is `num_classes - 1`. Data type must be float32 or float64.
|
|
|
|
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of
|
|
|
|
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`).
|
|
|
|
:math:`(\text{batch_size})`. The type must be int32.
|
|
|
|
The type must be int32. Each value in the tensor must not greater than `max_time`.
|
|
|
|
Each value in the tensor must not greater than `max_time`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
Outputs:
|
|
|
|
- **decoded_indices** (Tensor) - A tensor with shape of :math:`(\text{total_decoded_outputs}, 2)`.
|
|
|
|
- **decoded_indices** (Tensor) - A tensor with shape of (`total_decoded_outputs`, 2).
|
|
|
|
Data type is int64.
|
|
|
|
Data type is int64.
|
|
|
|
- **decoded_values** (Tensor) - A tensor with shape of :math:`(\text{total_decoded_outputs})`,
|
|
|
|
- **decoded_values** (Tensor) - A tensor with shape of (`total_decoded_outputs`),
|
|
|
|
it stores the decoded classes. Data type is int64.
|
|
|
|
it stores the decoded classes. Data type is int64.
|
|
|
|
- **decoded_shape** (Tensor) - The value of tensor is :math:`[\text{batch_size}, \text{max_decoded_legth}]`.
|
|
|
|
- **decoded_shape** (Tensor) - The value of tensor is [`batch_size`, `max_decoded_legth`].
|
|
|
|
Data type is int64.
|
|
|
|
Data type is int64.
|
|
|
|
- **log_probability** (Tensor) - A tensor with shape of :math:`(\text{batch_size}, 1)`,
|
|
|
|
- **log_probability** (Tensor) - A tensor with shape of (`batch_size`, 1),
|
|
|
|
containing sequence log-probability, has the same type as `inputs`.
|
|
|
|
containing sequence log-probability, has the same type as `inputs`.
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
@ -5519,6 +5518,129 @@ class BasicLSTMCell(PrimitiveWithInfer):
|
|
|
|
return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
|
|
|
|
return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicRNN(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
r"""
|
|
|
|
|
|
|
|
DynamicRNN Operator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
cell_type (str): An string identifying the cell type in the op. Default: 'LSTM'.
|
|
|
|
|
|
|
|
Only 'LSTM' is currently supported.
|
|
|
|
|
|
|
|
direction (str): An string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
|
|
|
|
|
|
|
|
Only 'UNIDIRECTIONAL' is currently supported.
|
|
|
|
|
|
|
|
cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
|
|
|
|
|
|
|
|
use_peephole (bool): An bool identifying if use peephole in the op. Default: False.
|
|
|
|
|
|
|
|
keep_prob (float): An float identifying the keep prob in the op. Default: 1.0.
|
|
|
|
|
|
|
|
cell_clip (float): An float identifying the cell clip in the op. Default: -1.0.
|
|
|
|
|
|
|
|
num_proj (int): An integer identifying the num proj in the op. Default: 0.
|
|
|
|
|
|
|
|
time_major (bool): An bool identifying the time major in the op. Default: True.
|
|
|
|
|
|
|
|
Only `True` is currently supported.
|
|
|
|
|
|
|
|
activation (str): An string identifying the type of activation function in the op. Default: 'tanh'.
|
|
|
|
|
|
|
|
Only 'tanh' is currently supported.
|
|
|
|
|
|
|
|
forget_bias (float): An float identifying the forget bias in the op. Default: 0.0.
|
|
|
|
|
|
|
|
is_training (bool): An bool identifying is training in the op. Default: True.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
|
|
|
- **x** (Tensor) - Current words. Tensor of shape (`num_step`, `batch_size`, `input_size`).
|
|
|
|
|
|
|
|
The data type must be float16 or float32.
|
|
|
|
|
|
|
|
- **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
|
|
|
|
|
|
|
|
The data type must be float16 or float32.
|
|
|
|
|
|
|
|
- **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
|
|
|
|
|
|
|
|
The data type must be float16 or float32.
|
|
|
|
|
|
|
|
- **seq_length (Tensor) - The length of each batch. Tensor of shape (`batch_size`).
|
|
|
|
|
|
|
|
Only `None` is currently supported.
|
|
|
|
|
|
|
|
- **init_h (Tensor) - Hidden state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
- **init_c (Tensor) - Cell state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
|
|
|
- **y** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
Has the same type with input `b`.
|
|
|
|
|
|
|
|
- **output_h** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
With data type of float16.
|
|
|
|
|
|
|
|
- **output_c** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
Has the same type with input `b`.
|
|
|
|
|
|
|
|
- **i** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
Has the same type with input `b`.
|
|
|
|
|
|
|
|
- **j** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
Has the same type with input `b`.
|
|
|
|
|
|
|
|
- **f** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
Has the same type with input `b`.
|
|
|
|
|
|
|
|
- **o** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
Has the same type with input `b`.
|
|
|
|
|
|
|
|
- **tanhct** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
|
|
|
|
|
|
|
Has the same type with input `b`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
>>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
|
|
|
|
|
|
|
|
>>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
|
|
|
|
|
|
|
|
>>> b = Tensor(np.random.rand(128).astype(np.float16))
|
|
|
|
|
|
|
|
>>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
|
|
|
|
|
|
|
|
>>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
|
|
|
|
|
|
|
|
>>> dynamic_rnn = P.DynamicRNN()
|
|
|
|
|
|
|
|
>>> output = lstm(x, w, b, None, init_h, init_c)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
|
|
|
cell_type='LSTM',
|
|
|
|
|
|
|
|
direction='UNIDIRECTIONAL',
|
|
|
|
|
|
|
|
cell_depth=1,
|
|
|
|
|
|
|
|
use_peephole=False,
|
|
|
|
|
|
|
|
keep_prob=1.0,
|
|
|
|
|
|
|
|
cell_clip=-1.0,
|
|
|
|
|
|
|
|
num_proj=0,
|
|
|
|
|
|
|
|
time_major=True,
|
|
|
|
|
|
|
|
activation='tanh',
|
|
|
|
|
|
|
|
forget_bias=0.0,
|
|
|
|
|
|
|
|
is_training=True):
|
|
|
|
|
|
|
|
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
|
|
|
|
|
|
|
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
|
|
|
|
|
|
|
|
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
|
|
|
|
|
|
|
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
|
|
|
|
|
|
|
|
self.num_proj = validator.check_value_type("num_proj", num_proj, [int], self.name)
|
|
|
|
|
|
|
|
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
|
|
|
|
|
|
|
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
|
|
|
|
|
|
|
|
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
|
|
|
|
|
|
|
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
|
|
|
|
|
|
|
|
self.cell_type = validator.check_string("cell_type", cell_type, ['LSTM'], self.name)
|
|
|
|
|
|
|
|
self.direction = validator.check_string("direction", direction, ['UNIDIRECTIONAL'], self.name)
|
|
|
|
|
|
|
|
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
|
|
|
|
|
|
|
|
self.add_prim_attr("io_format", "ND")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
|
|
|
|
|
|
|
|
validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
if seq_shape is not None:
|
|
|
|
|
|
|
|
raise ValueError(f"For {self.name}, seq_shape should be None.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_step, batch_size, input_size = x_shape
|
|
|
|
|
|
|
|
hidden_size = w_shape[-1] // 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
|
|
|
|
|
|
|
|
if w_shape[-1] % 4 != 0:
|
|
|
|
|
|
|
|
raise ValueError(f"For {self.name}, w_shape[-1] should multiple of 4.")
|
|
|
|
|
|
|
|
validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
|
|
|
|
|
|
|
|
input_size + hidden_size, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
y_shape = (num_step, batch_size, hidden_size)
|
|
|
|
|
|
|
|
return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
|
|
|
|
|
|
|
|
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name)
|
|
|
|
|
|
|
|
return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InTopK(PrimitiveWithInfer):
|
|
|
|
class InTopK(PrimitiveWithInfer):
|
|
|
|
r"""
|
|
|
|
r"""
|
|
|
|
Whether the targets are in the top `k` predictions.
|
|
|
|
Whether the targets are in the top `k` predictions.
|
|
|
|