diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index f12234e4f1..de93d71850 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -805,6 +805,22 @@ def get_bprop_lstm(self):
     return bprop
 
 
+@bprop_getters.register(inner.DynamicRNN)
+def get_bprop_dynamic_rnn(self):
+    """Grad definition for `DynamicRNN` operation."""
+    dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias)
+
+    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
+        dy, dh, dc, _, _, _, _, _ = dout
+        dh = dh[-1]
+        dc = dc[-1]
+        y, h, c, i, j, f, o, tanhct = out
+        dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
+                                                        c, dy, dh, dc, i, j, f, o, tanhct)
+        return dx, dw, db, (0), dh_prev, dc_prev
+    return bprop
+
+
 @bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
 def get_bprop_sigmoid_crossentropy_with_logits(self):
     """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 39f96afa37..82a145f0d6 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -274,6 +274,8 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe
 from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
 from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
 from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
+from .dynamic_rnn import _dynamic_rnn_tbe
+from .lstm_input_grad import _lstm_input_grad_tbe
 from .confusion_matrix import _confusion_matrix_tbe
 from .broadcast_to import _broadcast_to_tbe
 from .strided_read import _strided_read_tbe
diff --git a/mindspore/ops/_op_impl/tbe/dynamic_rnn.py b/mindspore/ops/_op_impl/tbe/dynamic_rnn.py
new file mode 100644
index 0000000000..1844a836f8
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/dynamic_rnn.py
@@ -0,0 +1,70 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""DynamicRNN op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("dynamic_rnn.so") \
+    .compute_cost(10) \
+    .kernel_name("dynamic_rnn") \
+    .attr("cell_type", "optional", "str", "all", "LSTM") \
+    .attr("direction", "optional", "str", "all", "UNIDIRECTIONAL") \
+    .attr("cell_depth", "optional", "int", "all", "1") \
+    .attr("use_peephole", "optional", "bool", "all", "false") \
+    .attr("keep_prob", "optional", "float", "all", "1") \
+    .attr("cell_clip", "optional", "float", "all", "-1") \
+    .attr("num_proj", "optional", "int", "all", "0") \
+    .attr("time_major", "optional", "bool", "all", "false") \
+    .attr("forget_bias", "optional", "float", "all", "0") \
+    .attr("is_training", "optional", "bool", "all", "true") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "w", False, "required", "all", reshape_type="CN") \
+    .input(2, "b", False, "required", "all") \
+    .input(3, "seq_length", False, "optional", "all") \
+    .input(4, "init_h", False, "optional", "all") \
+    .input(5, "init_c", False, "optional", "all") \
+    .input(6, "wci", False, "optional", "all") \
+    .input(7, "wcf", False, "optional", "all") \
+    .input(8, "wco", False, "optional", "all") \
+    .input(9, "mask", False, "optional", "all") \
+    .output(0, "y", False, "required", "all") \
+    .output(1, "output_h", False, "required", "all") \
+    .output(2, "output_c", False, "required", "all") \
+    .output(3, "i", False, "required", "all") \
+    .output(4, "j", False, "required", "all") \
+    .output(5, "f", False, "required", "all") \
+    .output(6, "o", False, "required", "all") \
+    .output(7, "tanhc", False, "required", "all") \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default,
+                  DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
+                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
+                  DataType.F32_FracNZ, DataType.F32_FracNZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.I32_Default,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ) \
+    .get_op_info()
+
+
+@op_info_register(dynamic_rnn_op_info)
+def _dynamic_rnn_tbe():
+    """DynamicRNN TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/lstm_input_grad.py b/mindspore/ops/_op_impl/tbe/lstm_input_grad.py
new file mode 100644
index 0000000000..1eca0da394
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/lstm_input_grad.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""LSTMInputGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+lstm_input_grad_op_info = TBERegOp("LSTMInputGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("lstm_input_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("lstm_input_grad") \
+    .partial_flag(True) \
+    .input(0, "w", False, "required", "all") \
+    .input(1, "init_c", False, "required", "all") \
+    .input(2, "c", False, "required", "all") \
+    .input(3, "dy", False, "required", "all") \
+    .input(4, "dh", False, "required", "all") \
+    .input(5, "dc", False, "required", "all") \
+    .input(6, "i", False, "required", "all") \
+    .input(7, "j", False, "required", "all") \
+    .input(8, "f", False, "required", "all") \
+    .input(9, "o", False, "required", "all") \
+    .input(10, "tanhct", False, "optional", "all") \
+    .output(0, "dx", False, "required", "all") \
+    .output(1, "dh_prev", False, "required", "all") \
+    .output(2, "dc_prev", False, "required", "all") \
+    .output(3, "dgate", False, "required", "all") \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
+                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
+    .get_op_info()
+
+
+@op_info_register(lstm_input_grad_op_info)
+def _lstm_input_grad_tbe():
+    """LSTMInputGrad TBE register"""
+    return
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 8df5c55f77..af103f0bd5 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1014,6 +1014,53 @@ class LSTMGrad(PrimitiveWithInfer):
         return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)
 
 
+class DynamicRNNGrad(PrimitiveWithInfer):
+    """Computes the input gradients of DynamicRNN."""
+
+    @prim_attr_register
+    def __init__(self,
+                 cell_type='LSTM',
+                 direction='UNIDIRECTIONAL',
+                 cell_depth=0,
+                 use_peephole=False,
+                 keep_prob=-1.0,
+                 cell_clip=-1.0,
+                 num_proj=0,
+                 time_major=False,
+                 forget_bias=0.0):
+        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
+        self.add_prim_attr("io_format", "ND")
+
+    def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
+                    c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
+        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
+        num_step, batch_size, input_size = x_shape
+        hidden_size = w_shape[-1] // 4
+        if w_shape[-1] % 4 != 0:
+            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
+        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
+                        input_size + hidden_size, Rel.EQ, self.name)
+        valid_shape = [num_step, batch_size, hidden_size]
+        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        validator.check("y_shape", y_shape, "expected shape", valid_shape, Rel.EQ, self.name)
+        validator.check("h_shape", h_shape, "expected shape", valid_shape, Rel.EQ, self.name)
+        validator.check("c_shape", c_shape, "expected shape", valid_shape, Rel.EQ, self.name)
+        validator.check("i_shape", i_shape, "expected shape", valid_shape, Rel.EQ, self.name)
validator.check("j_shape", j_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("f_shape", f_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("o_shape", o_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("tanhc_shape", tanhc_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("dh_shape", dh_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name) + validator.check("dc_shape", dc_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name) + + return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape + + def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype, + c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype): + return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype + + class PReLUGrad(PrimitiveWithInfer): r""" Gradients of PReLU operation. diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 60d0f86be1..7237b852c5 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -573,3 +573,111 @@ class MatrixSetDiag(PrimitiveWithInfer): x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name) return assist_shape + + +class DynamicRNN(PrimitiveWithInfer): + r""" + DynamicRNN Operator. + + Args: + cell_type (str): An string identifying the cell type in the op. Default: 'LSTM'. + Only 'LSTM' is currently supported. + direction (str): An string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. + Only 'UNIDIRECTIONAL' is currently supported. + cell_depth (int): An integer identifying the cell depth in the op. Default: 1. + use_peephole (bool): An bool identifying if use peephole in the op. Default: False. + keep_prob (float): An float identifying the keep prob in the op. Default: 1.0. + cell_clip (float): An float identifying the cell clip in the op. Default: -1.0. + num_proj (int): An integer identifying the num proj in the op. Default: 0. + time_major (bool): An bool identifying the time major in the op. Default: False. + forget_bias (float): An float identifying the forget bias in the op. Default: 0.0. + is_training (bool): An bool identifying is training in the op. Default: True. + + Inputs: + - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`. + The data type must be float16 or float32. + - **w** (Tensor) - Weight. Tensor of shape :math:`(input_size + hidden_size, 4 x hidden_size)`. + The data type must be float16 or float32. + - **b** (Tensor) - Bias. Tensor of shape :math:`(4 x hidden_size)`. + The data type must be float16 or float32. + - **seq_length (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`. + Only `None` is currently supported. + - **init_h (Tensor) - Hidden state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`. + - **init_c (Tensor) - Cell state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`. + + Outputs: + - **y** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. + Has the same type with input `b`. + - **output_h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. + With data type of float16. + - **output_c** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. + Has the same type with input `b`. 
+        - **i** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
+          Has the same data type as input `b`.
+        - **j** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
+          Has the same data type as input `b`.
+        - **f** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
+          Has the same data type as input `b`.
+        - **o** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
+          Has the same data type as input `b`.
+        - **tanhct** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
+          Has the same data type as input `b`.
+
+    Examples:
+        >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
+        >>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
+        >>> b = Tensor(np.random.rand(128).astype(np.float16))
+        >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
+        >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
+        >>> dynamic_rnn = P.DynamicRNN()
+        >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
+    """
+
+    @prim_attr_register
+    def __init__(self,
+                 cell_type='LSTM',
+                 direction='UNIDIRECTIONAL',
+                 cell_depth=1,
+                 use_peephole=False,
+                 keep_prob=1.0,
+                 cell_clip=-1.0,
+                 num_proj=0,
+                 time_major=False,
+                 forget_bias=0.0,
+                 is_training=True):
+        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
+        self.add_prim_attr("io_format", "ND")
+
+    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
+        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
+        validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
+        if seq_shape is not None:
+            raise ValueError(f"For {self.name}, seq_shape should be None.")
+
+        num_step, batch_size, input_size = x_shape
+        hidden_size = w_shape[-1] // 4
+
+        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
+        if w_shape[-1] % 4 != 0:
+            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
+        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
+                        input_size + hidden_size, Rel.EQ, self.name)
+        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
+        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
+        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
+        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
+
+        y_shape = (num_step, batch_size, hidden_size)
+        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
+
+    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
+        validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name)
+        validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name)
+        validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
+        validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name)
+        validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name)
+        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
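
Usage sketch (outside the diff, for review context): a minimal call of the new primitive, assuming an Ascend target (only a TBE kernel is registered here), PyNative-style direct invocation, and reaching the primitive through `_inner_ops` as the bprop registration above does. Shapes mirror the docstring example; the `P.DynamicRNN()` spelling in the docstring assumes a re-export that this diff does not add.

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.ops.operations import _inner_ops as inner

# Assumption: DynamicRNN only has a TBE (Ascend) kernel in this patch,
# so an Ascend device target is required for kernel selection.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

# Shapes follow the docstring example:
#   x: (num_step=2, batch_size=16, input_size=64)
#   w: (input_size + hidden_size, 4 * hidden_size) = (96, 128)
#   b: (4 * hidden_size,) = (128,)
x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
w = Tensor(np.random.rand(96, 128).astype(np.float16))
b = Tensor(np.random.rand(128).astype(np.float16))
init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))

dynamic_rnn = inner.DynamicRNN()
# seq_length must be passed as None for now; the op returns
# (y, output_h, output_c, i, j, f, o, tanhct), each of shape
# (num_step, batch_size, hidden_size).
y, output_h, output_c, i, j, f, o, tanhct = dynamic_rnn(x, w, b, None, init_h, init_c)
```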