move inner.DynamicRNN to P.DynamicRNN.

Branch: pull/6991/head
Author: liuxiao93 (4 years ago)
Parent: 11a04a889b
Commit: 34e6368d05

@@ -838,7 +838,7 @@ def get_bprop_lstm(self):
    return bprop


-@bprop_getters.register(inner.DynamicRNN)
+@bprop_getters.register(P.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Grad definition for `DynamicRNN` operation."""
    dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias)

@@ -29,7 +29,8 @@ dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
    .attr("keep_prob", "optional", "float", "all", "1") \
    .attr("cell_clip", "optional", "float", "all", "-1") \
    .attr("num_proj", "optional", "int", "all", "0") \
-   .attr("time_major", "optional", "bool", "all", "false") \
+   .attr("time_major", "optional", "bool", "all", "true") \
+   .attr("activation", "optional", "str", "all", "tanh") \
    .attr("forget_bias", "optional", "float", "all", "0") \
    .attr("is_training", "optional", "bool", "all", "true") \
    .partial_flag(True) \

@@ -71,7 +71,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
    MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
    ResizeBilinear, Sigmoid,
    SigmoidCrossEntropyWithLogits,
-   SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss,
+   SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN,
    SoftmaxCrossEntropyWithLogits, ROIAlign,
    SparseSoftmaxCrossEntropyWithLogits, Tanh,
    TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
@@ -230,6 +230,7 @@ __all__ = [
    'CTCLoss',
    'CTCGreedyDecoder',
    'RNNTLoss',
+   'DynamicRNN',
    'ReduceAll',
    'ReduceAny',
    'ScalarToArray',
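
With the import and `__all__` entries above in place, DynamicRNN resolves through the public operations namespace rather than the internal `inner` module. A minimal sketch of the effect (illustrative only, not part of this commit):

from mindspore.ops import operations as P

dynamic_rnn = P.DynamicRNN()         # the primitive now constructs from the public namespace
print(type(dynamic_rnn).__name__)    # DynamicRNN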

@@ -451,114 +451,6 @@ class MatrixSetDiag(PrimitiveWithInfer):
        return assist_shape
class DynamicRNN(PrimitiveWithInfer):
    r"""
    DynamicRNN Operator.

    Args:
        cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
            Only 'LSTM' is currently supported.
        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
            Only 'UNIDIRECTIONAL' is currently supported.
        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
        use_peephole (bool): A bool identifying whether to use peephole in the op. Default: False.
        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
        num_proj (int): An integer identifying the num proj in the op. Default: 0.
        time_major (bool): A bool identifying the time major in the op. Default: False.
        forget_bias (float): A float identifying the forget bias in the op. Default: 0.0.
        is_training (bool): A bool identifying whether the op is in training mode. Default: True.

    Inputs:
        - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
          The data type must be float16 or float32.
        - **w** (Tensor) - Weight. Tensor of shape :math:`(input_size + hidden_size, 4 x hidden_size)`.
          The data type must be float16 or float32.
        - **b** (Tensor) - Bias. Tensor of shape :math:`(4 x hidden_size)`.
          The data type must be float16 or float32.
        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`.
          Only `None` is currently supported.
        - **init_h** (Tensor) - Hidden state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`.
        - **init_c** (Tensor) - Cell state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`.

    Outputs:
        - **y** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same data type as input `b`.
        - **output_h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          With data type of float16.
        - **output_c** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same data type as input `b`.
        - **i** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same data type as input `b`.
        - **j** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same data type as input `b`.
        - **f** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same data type as input `b`.
        - **o** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same data type as input `b`.
        - **tanhct** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same data type as input `b`.

    Examples:
        >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
        >>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
        >>> b = Tensor(np.random.rand(128).astype(np.float16))
        >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> dynamic_rnn = P.DynamicRNN()
        >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
    """

    @prim_attr_register
    def __init__(self,
                 cell_type='LSTM',
                 direction='UNIDIRECTIONAL',
                 cell_depth=1,
                 use_peephole=False,
                 keep_prob=1.0,
                 cell_clip=-1.0,
                 num_proj=0,
                 time_major=False,
                 forget_bias=0.0,
                 is_training=True):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
        validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
        validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
        if seq_shape is not None:
            raise ValueError(f"For {self.name}, seq_shape should be None.")

        num_step, batch_size, input_size = x_shape
        hidden_size = w_shape[-1] // 4
        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
        if w_shape[-1] % 4 != 0:
            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
                        input_size + hidden_size, Rel.EQ, self.name)
        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)

        y_shape = (num_step, batch_size, hidden_size)
        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
        validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name)
        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the dot product result of input0 and input1.

@@ -5284,18 +5284,18 @@ class CTCLoss(PrimitiveWithInfer):
    Inputs:
        - **inputs** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
-         :math:`(max_time, batch_size, num_classes)`. `num_classes` must be `num_labels + 1` classes, `num_labels`
+         (`max_time`, `batch_size`, `num_classes`). `num_classes` must be `num_labels + 1` classes, `num_labels`
          indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
          Data type must be float16, float32 or float64.
        - **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
          stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
        - **labels_values** (Tensor) - A `1-D` input tensor. The values are associated with the given batch and time.
          The type must be int32. `labels_values[i]` must in the range of `[0, num_classes)`.
-       - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
+       - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`).
          The type must be int32. Each value in the tensor must not be greater than `max_time`.

    Outputs:
-       - **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch_size)`. The tensor has
+       - **loss** (Tensor) - A tensor containing log-probabilities, the shape is (`batch_size`). The tensor has
          the same type with `inputs`.
        - **gradient** (Tensor) - The gradient of `loss`, has the same type and shape with `inputs`.
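
The reworked CTCLoss docstring lists four inputs; the sketch below exercises them with shapes consistent with that description (the concrete values are illustrative assumptions, not taken from this commit):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# (max_time=2, batch_size=2, num_classes=3); num_classes = num_labels + 1
inputs = Tensor(np.random.rand(2, 2, 3).astype(np.float32))
# labels_indices[i, :] == [b, t]: labels_values[i] is the label of batch b at time t
labels_indices = Tensor(np.array([[0, 0], [1, 0]], dtype=np.int64))
labels_values = Tensor(np.array([2, 2], dtype=np.int32))      # each id in [0, num_classes)
sequence_length = Tensor(np.array([2, 2], dtype=np.int32))    # each value <= max_time

ctc_loss = P.CTCLoss()
loss, gradient = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
# loss has shape (batch_size,); gradient has the same shape and type as `inputs`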
@@ -5353,21 +5353,20 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
    Inputs:
        - **inputs** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
-         :math:`(\text{max_time}, \text{batch_size}, \text{num_classes})`. `num_classes` must be
-         `num_labels + 1` classes, `num_labels` indicates the number of actual labels. Blank labels are reserved.
+         (`max_time`, `batch_size`, `num_classes`). `num_classes` must be `num_labels + 1` classes,
+         `num_labels` indicates the number of actual labels. Blank labels are reserved.
          Default blank label is `num_classes - 1`. Data type must be float32 or float64.
-       - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of
-         :math:`(\text{batch_size})`. The type must be int32.
-         Each value in the tensor must not greater than `max_time`.
+       - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`).
+         The type must be int32. Each value in the tensor must not be greater than `max_time`.

    Outputs:
-       - **decoded_indices** (Tensor) - A tensor with shape of :math:`(\text{total_decoded_outputs}, 2)`.
+       - **decoded_indices** (Tensor) - A tensor with shape of (`total_decoded_outputs`, 2).
          Data type is int64.
-       - **decoded_values** (Tensor) - A tensor with shape of :math:`(\text{total_decoded_outputs})`,
+       - **decoded_values** (Tensor) - A tensor with shape of (`total_decoded_outputs`),
          it stores the decoded classes. Data type is int64.
-       - **decoded_shape** (Tensor) - The value of tensor is :math:`[\text{batch_size}, \text{max_decoded_legth}]`.
+       - **decoded_shape** (Tensor) - The value of tensor is [`batch_size`, `max_decoded_length`].
          Data type is int64.
-       - **log_probability** (Tensor) - A tensor with shape of :math:`(\text{batch_size}, 1)`,
+       - **log_probability** (Tensor) - A tensor with shape of (`batch_size`, 1),
          containing sequence log-probability, has the same type as `inputs`.

    Examples:
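
Likewise for CTCGreedyDecoder, a short usage sketch that follows the documented input and output shapes (again an illustrative assumption, not code from this commit):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

inputs = Tensor(np.random.rand(2, 2, 3).astype(np.float32))  # (max_time, batch_size, num_classes)
sequence_length = Tensor(np.array([2, 2], dtype=np.int32))   # each value <= max_time

decoder = P.CTCGreedyDecoder()
decoded_indices, decoded_values, decoded_shape, log_probability = decoder(inputs, sequence_length)
# decoded_indices: (total_decoded_outputs, 2), int64
# decoded_values:  (total_decoded_outputs,),   int64
# decoded_shape:   [batch_size, max_decoded_length], int64
# log_probability: (batch_size, 1), same type as `inputs`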
@@ -5519,6 +5518,129 @@ class BasicLSTMCell(PrimitiveWithInfer):
        return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
class DynamicRNN(PrimitiveWithInfer):
    r"""
    DynamicRNN Operator.

    Args:
        cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
            Only 'LSTM' is currently supported.
        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
            Only 'UNIDIRECTIONAL' is currently supported.
        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
        use_peephole (bool): A bool identifying whether to use peephole in the op. Default: False.
        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
        num_proj (int): An integer identifying the num proj in the op. Default: 0.
        time_major (bool): A bool identifying the time major in the op. Default: True.
            Only `True` is currently supported.
        activation (str): A string identifying the type of activation function in the op. Default: 'tanh'.
            Only 'tanh' is currently supported.
        forget_bias (float): A float identifying the forget bias in the op. Default: 0.0.
        is_training (bool): A bool identifying whether the op is in training mode. Default: True.

    Inputs:
        - **x** (Tensor) - Current words. Tensor of shape (`num_step`, `batch_size`, `input_size`).
          The data type must be float16 or float32.
        - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
          The data type must be float16 or float32.
        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
          The data type must be float16 or float32.
        - **seq_length** (Tensor) - The length of each batch. Tensor of shape (`batch_size`).
          Only `None` is currently supported.
        - **init_h** (Tensor) - Hidden state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`).
        - **init_c** (Tensor) - Cell state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`).

    Outputs:
        - **y** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          Has the same data type as input `b`.
        - **output_h** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          With data type of float16.
        - **output_c** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          Has the same data type as input `b`.
        - **i** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          Has the same data type as input `b`.
        - **j** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          Has the same data type as input `b`.
        - **f** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          Has the same data type as input `b`.
        - **o** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          Has the same data type as input `b`.
        - **tanhct** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
          Has the same data type as input `b`.

    Examples:
        >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
        >>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
        >>> b = Tensor(np.random.rand(128).astype(np.float16))
        >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> dynamic_rnn = P.DynamicRNN()
        >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
    """

    @prim_attr_register
    def __init__(self,
                 cell_type='LSTM',
                 direction='UNIDIRECTIONAL',
                 cell_depth=1,
                 use_peephole=False,
                 keep_prob=1.0,
                 cell_clip=-1.0,
                 num_proj=0,
                 time_major=True,
                 activation='tanh',
                 forget_bias=0.0,
                 is_training=True):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
        self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
        self.num_proj = validator.check_value_type("num_proj", num_proj, [int], self.name)
        self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
        self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
        self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
        self.cell_type = validator.check_string("cell_type", cell_type, ['LSTM'], self.name)
        self.direction = validator.check_string("direction", direction, ['UNIDIRECTIONAL'], self.name)
        self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
        validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
        validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
        if seq_shape is not None:
            raise ValueError(f"For {self.name}, seq_shape should be None.")

        num_step, batch_size, input_size = x_shape
        hidden_size = w_shape[-1] // 4
        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
        if w_shape[-1] % 4 != 0:
            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
                        input_size + hidden_size, Rel.EQ, self.name)
        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)

        y_shape = (num_step, batch_size, hidden_size)
        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
        validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name)
        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
class InTopK(PrimitiveWithInfer):
    r"""
    Whether the targets are in the top `k` predictions.
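
Tying the relocation together with the docstring example above, the following sketch shows the shapes P.DynamicRNN should produce according to `infer_shape` (an Ascend target is assumed, since the op is registered through TBE; the snippet is illustrative, not part of the commit):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="Ascend")   # DynamicRNN is an Ascend (TBE) kernel

num_step, batch_size, input_size, hidden_size = 2, 16, 64, 32
x = Tensor(np.random.rand(num_step, batch_size, input_size).astype(np.float16))
w = Tensor(np.random.rand(input_size + hidden_size, 4 * hidden_size).astype(np.float16))
b = Tensor(np.random.rand(4 * hidden_size).astype(np.float16))
init_h = Tensor(np.random.rand(1, batch_size, hidden_size).astype(np.float16))
init_c = Tensor(np.random.rand(1, batch_size, hidden_size).astype(np.float16))

dynamic_rnn = P.DynamicRNN()   # reachable as P.DynamicRNN after this change
y, output_h, output_c, i, j, f, o, tanhct = dynamic_rnn(x, w, b, None, init_h, init_c)
# every output has shape (num_step, batch_size, hidden_size) == (2, 16, 32)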
