diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h index e375fa9ff7..3a49b5d3c7 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h @@ -187,6 +187,8 @@ constexpr const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad"; constexpr const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad"; constexpr const char kNameDynamicRNN[] = "DynamicRNN"; constexpr const char kNameDynamicRNNGrad[] = "DynamicRNNGrad"; +constexpr const char kNameDynamicGRUV2[] = "DynamicGRUV2"; +constexpr const char kNameDynamicGRUV2Grad[] = "DynamicGRUV2Grad"; constexpr const char kNameL2Loss[] = "L2Loss"; constexpr const char kNameCTCLoss[] = "CTCLoss"; constexpr const char kNameRange[] = "Range"; diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc index 484a76a54a..a008d919ed 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc @@ -92,4 +92,42 @@ OUTPUT_MAP(DynamicRNNGrad) = {{0, OUTPUT_DESC(dw)}, {3, OUTPUT_DESC(dh_prev)}, {4, OUTPUT_DESC(dc_prev)}}; REG_ADPT_DESC(DynamicRNNGrad, kNameDynamicRNNGrad, ADPT_DESC(DynamicRNNGrad)) + +// DynamicGRUV2 +INPUT_MAP(DynamicGRUV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(weight_input)}, {3, INPUT_DESC(weight_hidden)}, + {4, INPUT_DESC(bias_input)}, {5, INPUT_DESC(bias_hidden)}, {6, INPUT_DESC(seq_length)}, + {7, INPUT_DESC(init_h)}}; +ATTR_MAP(DynamicGRUV2) = {{"direction", ATTR_DESC(direction, AnyTraits())}, + {"cell_depth", ATTR_DESC(cell_depth, AnyTraits())}, + {"keep_prob", ATTR_DESC(keep_prob, AnyTraits())}, + {"cell_clip", ATTR_DESC(cell_clip, AnyTraits())}, + {"num_proj", ATTR_DESC(num_proj, AnyTraits())}, + {"time_major", ATTR_DESC(time_major, AnyTraits())}, + {"activation", ATTR_DESC(direction, AnyTraits())}, + {"gate_order", ATTR_DESC(gate_order, AnyTraits())}, + {"reset_after", ATTR_DESC(reset_after, AnyTraits())}, + {"is_training", ATTR_DESC(is_training, AnyTraits())}}; +OUTPUT_MAP(DynamicGRUV2) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(output_h)}, {2, OUTPUT_DESC(update)}, + {3, OUTPUT_DESC(reset)}, {4, OUTPUT_DESC(new)}, {5, OUTPUT_DESC(hidden_new)}}; +REG_ADPT_DESC(DynamicGRUV2, kNameDynamicGRUV2, ADPT_DESC(DynamicGRUV2)) + +// DynamicGRUV2Grad +INPUT_MAP(DynamicGRUV2Grad) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(weight_input)}, {3, INPUT_DESC(weight_hidden)}, + {4, INPUT_DESC(y)}, {5, INPUT_DESC(init_h)}, {6, INPUT_DESC(h)}, + {7, INPUT_DESC(dy)}, {8, INPUT_DESC(dh)}, {9, INPUT_DESC(update)}, + {10, INPUT_DESC(reset)}, {11, INPUT_DESC(new)}, {12, INPUT_DESC(hidden_new)}, + {13, INPUT_DESC(seq_length)}, {14, INPUT_DESC(mask)}}; +ATTR_MAP(DynamicGRUV2Grad) = {{"direction", ATTR_DESC(direction, AnyTraits())}, + {"cell_depth", ATTR_DESC(cell_depth, AnyTraits())}, + {"keep_prob", ATTR_DESC(keep_prob, AnyTraits())}, + {"cell_clip", ATTR_DESC(cell_clip, AnyTraits())}, + {"num_proj", ATTR_DESC(num_proj, AnyTraits())}, + {"time_major", ATTR_DESC(time_major, AnyTraits())}, + {"bias_type", ATTR_DESC(bias_type, AnyTraits())}, + {"gate_order", ATTR_DESC(gate_order, AnyTraits())}, + {"reset_after", ATTR_DESC(reset_after, AnyTraits())}}; +OUTPUT_MAP(DynamicGRUV2Grad) = {{0, OUTPUT_DESC(dw_input)}, {1, OUTPUT_DESC(dw_hidden)}, {2, OUTPUT_DESC(db_input)}, + {3, OUTPUT_DESC(db_hidden)}, {4, OUTPUT_DESC(dx)}, {5, OUTPUT_DESC(dh_prev)}}; +REG_ADPT_DESC(DynamicGRUV2Grad, kNameDynamicGRUV2Grad, ADPT_DESC(DynamicGRUV2Grad)) } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h index 0939fdb131..dc9f6c43d1 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h @@ -40,5 +40,11 @@ DECLARE_OP_USE_OUTPUT(DynamicRNN) DECLARE_OP_ADAPTER(DynamicRNNGrad) DECLARE_OP_USE_OUTPUT(DynamicRNNGrad) + +DECLARE_OP_ADAPTER(DynamicGRUV2) +DECLARE_OP_USE_OUTPUT(DynamicGRUV2) + +DECLARE_OP_ADAPTER(DynamicGRUV2Grad) +DECLARE_OP_USE_OUTPUT(DynamicGRUV2Grad) } // namespace mindspore::transform #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RNN_DECLARE_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 83b6ff2fb6..7342941545 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -225,6 +225,8 @@ constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad"; constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell"; constexpr auto kDynamicRNNOpName = "DynamicRNN"; constexpr auto kLSTMInputGradOpName = "LSTMInputGrad"; +constexpr auto kDynamicGRUOpName = "DynamicGRU"; +constexpr auto kGRUV2HiddenGrad = "GRUV2HiddenGrad"; constexpr auto kFusedSparseFtrlName = "FusedSparseFtrl"; constexpr auto kFusedSparseProximalAdagradName = "FusedSparseProximalAdagrad"; constexpr auto kFusedSparseLazyAdamName = "FusedSparseLazyAdam"; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index db8fa1ab6f..a5490b892c 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -110,6 +110,8 @@ inline const PrimitivePtr kPrimUniqueGrad = std::make_shared("UniqueG inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared("ExtractImagePatches"); inline const PrimitivePtr kPrimDynamicRNN = std::make_shared("DynamicRNN"); inline const PrimitivePtr kPrimDynamicRNNGrad = std::make_shared("DynamicRNNGrad"); +inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared("DynamicGRUV2"); +inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared("DynamicGRUV2Grad"); inline const PrimitivePtr kPrimScatterAdd = std::make_shared("ScatterAdd"); inline const PrimitivePtr kPrimScatterUpdate = std::make_shared("ScatterUpdate"); inline const PrimitivePtr kPrimDiv = std::make_shared("Div"); diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index b1654c3cfc..01386b2a6e 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -849,7 +849,16 @@ def get_bprop_lstm(self): @bprop_getters.register(P.DynamicRNN) def get_bprop_dynamic_rnn(self): """Grad definition for `DynamicRNN` operation.""" - dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias) + dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type, + direction=self.direction, + cell_depth=self.cell_depth, + use_peephole=self.use_peephole, + keep_prob=self.keep_prob, + cell_clip=self.cell_clip, + num_proj=self.num_proj, + time_major=self.time_major, + forget_bias=self.forget_bias) + expand_dims = P.ExpandDims() def bprop(x, w, b, seq_length, init_h, init_c, out, dout): dy, dh, dc, _, _, _, _, _, = dout @@ -858,10 +867,30 @@ def get_bprop_dynamic_rnn(self): y, h, c, i, j, f, o, tanhct = out dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h, c, dy, dh, dc, i, j, f, o, tanhct) + dh_prev = expand_dims(dh_prev, 0) + dc_prev = expand_dims(dc_prev, 0) return dx, dw, db, (0), dh_prev, dc_prev return bprop +@bprop_getters.register(inner.DynamicGRUV2) +def get_bprop_dynamic_gru_v2(self): + """Grad definition for `DynamicGRUV2` operation.""" + dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip, + self.num_proj, self.time_major, 'double_bias', self.gate_order, + self.reset_after) + + def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout): + y, out_h, update, reset, new, hidden_new = out + dy, dout_h, _, _, _, _ = dout + + dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h, + out_h, dy, dout_h[-1], update, + reset, new, hidden_new, None, None) + return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev + return bprop + + @bprop_getters.register(P.SigmoidCrossEntropyWithLogits) def get_bprop_sigmoid_crossentropy_with_logits(self): """Grad definition for `SigmoidCrossEntropyWithLogits` operation.""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 7885ee8bfb..a27f14412a 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -286,6 +286,8 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe from .dynamic_rnn import _dynamic_rnn_tbe +from .dynamic_gru_v2 import _dynamic_gru_v2_tbe +from .gru_v2_hidden_grad import _gru_v2_hidden_grad_tbe from .lstm_input_grad import _lstm_input_grad_tbe from .confusion_matrix import _confusion_matrix_tbe from .broadcast_to import _broadcast_to_tbe diff --git a/mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py b/mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py new file mode 100644 index 0000000000..b3d2c99392 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py @@ -0,0 +1,63 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""DynamicGRUV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("dynamic_gru_v2.so") \ + .compute_cost(10) \ + .kernel_name("dynamic_gru_v2") \ + .attr("direction", "optional", "str", "all", "UNIDIRECTIONAL") \ + .attr("cell_depth", "optional", "int", "all", "1") \ + .attr("keep_prob", "optional", "float", "all", "1") \ + .attr("cell_clip", "optional", "float", "all", "-1") \ + .attr("num_proj", "optional", "int", "all", "0") \ + .attr("time_major", "optional", "bool", "all", "true") \ + .attr("activation", "optional", "str", "all", "tanh") \ + .attr("gate_order", "optional", "str", "all", "rzh") \ + .attr("reset_after", "optional", "bool", "all", "true") \ + .attr("is_training", "optional", "bool", "all", "true") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "weight_input", False, "required", "all") \ + .input(2, "weight_hidden", False, "required", "all") \ + .input(3, "bias_input", False, "optional", "all") \ + .input(4, "bias_hidden", False, "optional", "all") \ + .input(5, "seq_length", False, "optional", "all") \ + .input(6, "init_h", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .output(1, "output_h", False, "required", "all") \ + .output(2, "update", False, "optional", "all") \ + .output(3, "reset", False, "optional", "all") \ + .output(4, "new", False, "optional", "all") \ + .output(5, "hidden_new", False, "optional", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, + DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, + DataType.F16_Default, DataType.I32_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ) \ + .get_op_info() + + +@op_info_register(dynamic_gru_v2_op_info) +def _dynamic_gru_v2_tbe(): + """DynamicGRUV2 TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py b/mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py new file mode 100644 index 0000000000..99d04758f0 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""GRUV2HiddenGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +gru_v2_hidden_grad_op_info = TBERegOp("GRUV2HiddenGrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("gru_v2_hidden_grad.so") \ + .compute_cost(10) \ + .kernel_name("gru_v2_hidden_grad") \ + .attr("gate_order", "optional", "str", "all", "zrh") \ + .partial_flag(True) \ + .input(0, "weight_input", False, "required", "all") \ + .input(1, "init_h", False, "required", "all") \ + .input(2, "h", False, "required", "all") \ + .input(3, "dy", False, "optional", "all") \ + .input(4, "dh", False, "optional", "all") \ + .input(5, "update", False, "optional", "all") \ + .input(6, "reset", False, "optional", "all") \ + .input(7, "new", False, "optional", "all") \ + .input(8, "hidden_new", False, "optional", "all") \ + .output(0, "dh_preh", False, "required", "all") \ + .output(1, "dgate_h", False, "required", "all") \ + .output(2, "dnt_x", False, "optional", "all") \ + .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, + DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, + DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ + .get_op_info() + + +@op_info_register(gru_v2_hidden_grad_op_info) +def _gru_v2_hidden_grad_tbe(): + """DynamicGRUV2 TBE register""" + return diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index c7406bdb80..fc58c2c603 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1095,9 +1095,9 @@ class DynamicRNNGrad(PrimitiveWithInfer): def __init__(self, cell_type='LSTM', direction='UNIDIRECTIONAL', - cell_depth=0, + cell_depth=1, use_peephole=False, - keep_prob=-1.0, + keep_prob=1.0, cell_clip=-1.0, num_proj=0, time_major=True, @@ -1135,6 +1135,147 @@ class DynamicRNNGrad(PrimitiveWithInfer): return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype +class DynamicGRUV2Grad(PrimitiveWithInfer): + r""" + Computes the input gradients of DynamicGRUV2. + + Args: + direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. + Only 'UNIDIRECTIONAL' is currently supported. + cell_depth (int): An integer identifying the cell depth in the op. Default: 1. + keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. + cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. + num_proj (int): An integer identifying the num proj in the op. Default: 0. + time_major (bool): A bool identifying the time major in the op. Default: True. + bias_type (str): An string identifying the type of bias_type function in the op. Default to "double_bias". + gate_order (str): An string identifying the gate order in weight and bias. Default: 'rzh. + 'zrh' is another option. + reset_after (bool): An bool identifying whether to apply reset gate after matrix multiplication. Default: True. + + Inputs: + - **x** (Tensor) - Current words. Tensor of shape :math:`({num_step, batch_size, input_size)`. + The data type must be float16 or float32. + - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`. + The data type must be float16 or float32. + - **weight_hidden** (Tensor) - Bias. Tensor of shape :math:`(hidden_size, 3 x hidden_size)`. + The data type must be float16 or float32. + - **y** (Tensor) - A Tensor of shape :math: + if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, + if num_proj == 0 `(num_step, batch_size, hidden_size)`. + The data type must be float16 or float32. + - **init_h** (Tensor) - Hidden state of initial time. + Tensor of shape :math:`(batch_size, hidden_size)`, or None. + The data type must be float16 or float32. + - **h** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. + The data type must be float16 or float32. + - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`. + - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `h`. + - **update** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. + The data type must be float16 or float32. + - **reset** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. + The data type must be float16 or float32. + - **new** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. + The data type must be float16 or float32. + - **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. + The data type must be float16 or float32. + - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`. + Only `None` is currently supported. + - **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32. + + Outputs: + - **dw_input** (Tensor) - A Tensor has the same shape as `weight_input`. + Has the same type with input `x`. + - **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`. + Has the same type with input `x`. + - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`. + Has the same type with input `x`. + - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`. + Has the same type with input `x`. + - **dx** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. + Has the same type with input `x`. + - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch_size, hidden_size)`. + Has the same type with input `x`. + """ + + @prim_attr_register + def __init__(self, + direction='UNIDIRECTIONAL', + cell_depth=1, + keep_prob=1.0, + cell_clip=-1.0, + num_proj=0, + time_major=True, + bias_type="double_bias", + gate_order="zrh", + reset_after=True): + self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) + self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) + self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) + self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) + self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) + self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) + self.bias_type = validator.check_string(bias_type, + ['no_bias', 'single_bias', 'double_bias'], "bias_type", self.name) + self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) + self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) + self.add_prim_attr("io_format", "ND") + + def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape, + dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape): + validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) + validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) + validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) + validator.check_int(len(y_shape), 3, Rel.EQ, "y shape rank", self.name) + num_step, batch_size, input_size = x_shape + hidden_size = whidden_shape[0] + validator.check("weight_hidden_shape[-1]", whidden_shape[-1], "3 * hidden_size", + 3 * hidden_size, Rel.EQ, self.name) + validator.check("weight_input_shape", winput_shape, "excepted shape", + [input_size, 3 * hidden_size], Rel.EQ, self.name) + if self.num_proj > 0: + valid_y_shape = [num_step, batch_size, min(hidden_size, self.num_proj)] + else: + valid_y_shape = [num_step, batch_size, hidden_size] + validator.check("y_shape", y_shape, "excepted shape", valid_y_shape, Rel.EQ, self.name) + + validator.check("init_h_shape", init_h_shape, "excepted shape", + [batch_size, hidden_size], Rel.EQ, self.name) + valid_shape = [num_step, batch_size, hidden_size] + validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("dh_shape", dh_shape, "excepted shape", + [batch_size, hidden_size], Rel.EQ, self.name) + validator.check("update_shape", update_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("reset_shape", reset_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("new_shape", new_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + validator.check("hnew_shape", hnew_shape, "excepted shape", valid_shape, Rel.EQ, self.name) + if seq_shape is not None: + validator.check("seq_shape", seq_shape, "batch_size", batch_size, Rel.EQ, self.name) + + dx_shape = (num_step, batch_size, input_size) + dh_shape = (batch_size, hidden_size) + dwinput_shape = (input_size, 3 * hidden_size) + dwhidden_shape = (hidden_size, 3 * hidden_size) + db_shape = (3 * hidden_size,) + return dwinput_shape, dwhidden_shape, db_shape, db_shape, dx_shape, dh_shape + + def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype, + dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype): + valid_types = (mstype.float16, mstype.float32) + args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype, + "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype, + "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype} + validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name) + validator.check_tensor_type_same(args, valid_types, self.name) + if seq_dtype is not None: + validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name) + if mask_dtype is not None: + validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name) + return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype + + class PReLUGrad(PrimitiveWithInfer): r""" Gradients of PReLU operation. diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 17382bdd52..5a4e36dd23 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -451,6 +451,157 @@ class MatrixSetDiag(PrimitiveWithInfer): return assist_shape +class DynamicGRUV2(PrimitiveWithInfer): + r""" + DynamicGRUV2 Operator. + + Args: + direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. + Only 'UNIDIRECTIONAL' is currently supported. + cell_depth (int): An integer identifying the cell depth in the op. Default: 1. + keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. + cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. + num_proj (int): An integer identifying the num proj in the op. Default: 0. + time_major (bool): A bool identifying the time major in the op. Default: True. + activation (str) : A string identifying the type of activation function in the op. Default: 'tanh'. + Only 'tanh' is currently supported. + gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh. + 'zrh' is another option. + reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True. + is_training (bool): A bool identifying is training in the op. Default: True. + + Inputs: + - **x** (Tensor) - Current words. + Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`. + The data type must be float16. + - **weight_input** (Tensor) - Input-hidden weight. + Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`. + The data type must be float16. + - **weight_hidden** (Tensor) - Hidden-hidden weight. + Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. + The data type must be float16. + - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. + The data type must be float16 or float32. + - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. + The data type must be float16 or float32. + - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. + Only `None` is currently supported. + - **init_h** (Tensor) - Hidden state of initial time. + Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`, or None. + The data type must be float16 or float32. + + Outputs: + - **y** (Tensor) - A Tensor of shape :math: + if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, + if num_proj == 0 `(num_step, batch_size, hidden_size)`. + Has the same data type with input `bais_type`. + - **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type with input `bais_type`. + - **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type with input `bais_type`. + - **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type with input `bais_type`. + - **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type with input `bais_type`. + - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type with input `bais_type`. + + - If `bias_input`, `bias_hidden` and `init_h` all are `None`, `bias_type` is float32. + - If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`. + - If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`. + - Otherwise, `bias_type` is the date type of `init_h`. + + Examples: + >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16)) + >>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16)) + >>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16)) + >>> bias_i = Tensor(np.random.rand(48).astype(np.float16)) + >>> bias_h = Tensor(np.random.rand(48).astype(np.float16)) + >>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16)) + >>> dynamic_gru_v2 = P.DynamicGRUV2() + >>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h) + >>> output[0].shape + (2, 8, 16) + """ + + @prim_attr_register + def __init__(self, + direction='UNIDIRECTIONAL', + cell_depth=1, + keep_prob=1.0, + cell_clip=-1.0, + num_proj=0, + time_major=True, + activation="tanh", + gate_order="rzh", + reset_after=True, + is_training=True): + self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) + self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) + self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) + self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) + self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) + self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) + self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) + self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) + self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) + self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) + self.add_prim_attr("io_format", "ND") + + def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape): + validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) + validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) + validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) + if binput_shape is not None: + validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name) + if bhidden_shape is not None: + validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name) + if h_shape is not None: + validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name) + if seq_shape is not None: + raise ValueError(f"For {self.name}, seq_shape should be None.") + + num_step, batch_size, input_size = x_shape + hidden_size = winput_shape[-1] // 3 + + if winput_shape[-1] % 3 != 0: + raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.") + + validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]", + whidden_shape[-1], Rel.EQ, self.name) + validator.check("bias_input_shape", binput_shape, "bias_hidden_shape", bhidden_shape, Rel.EQ, self.name) + validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) + validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name) + if h_shape is not None: + validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name) + validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name) + if self.num_proj > 0: + y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) + else: + y_shape = (num_step, batch_size, hidden_size) + outh_shape = (num_step, batch_size, hidden_size) + return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape + + def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): + validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) + validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name) + validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name) + b_dtype = mstype.float32 + if binput_dtype is not None: + validator.check_tensor_type_same({"bias input dtype": binput_dtype}, + (mstype.float16, mstype.float32), self.name) + b_dtype = binput_dtype + elif bhidden_dtype is not None: + validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype}, + (mstype.float16, mstype.float32), self.name) + b_dtype = bhidden_dtype + elif h_dtype is not None: + validator.check_tensor_type_same({"init_h dtype": h_dtype}, + (mstype.float16, mstype.float32), self.name) + b_dtype = h_dtype + return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype + + class ConfusionMulGrad(PrimitiveWithInfer): """ `output0` is the dot product result of input0 and input1. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a50f04efff..68d9154a7f 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5611,33 +5611,35 @@ class DynamicRNN(PrimitiveWithInfer): DynamicRNN Operator. Args: - cell_type (str): An string identifying the cell type in the op. Default: 'LSTM'. + cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'. Only 'LSTM' is currently supported. - direction (str): An string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. + direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. Only 'UNIDIRECTIONAL' is currently supported. cell_depth (int): An integer identifying the cell depth in the op. Default: 1. - use_peephole (bool): An bool identifying if use peephole in the op. Default: False. - keep_prob (float): An float identifying the keep prob in the op. Default: 1.0. - cell_clip (float): An float identifying the cell clip in the op. Default: -1.0. + use_peephole (bool): A bool identifying if use peephole in the op. Default: False. + keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. + cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. num_proj (int): An integer identifying the num proj in the op. Default: 0. - time_major (bool): An bool identifying the time major in the op. Default: True. + time_major (bool): A bool identifying the time major in the op. Default: True. Only `True` is currently supported. - activation (str): An string identifying the type of activation function in the op. Default: 'tanh'. + activation (str): A string identifying the type of activation function in the op. Default: 'tanh'. Only 'tanh' is currently supported. - forget_bias (float): An float identifying the forget bias in the op. Default: 0.0. - is_training (bool): An bool identifying is training in the op. Default: True. + forget_bias (float): A float identifying the forget bias in the op. Default: 0.0. + is_training (bool): A bool identifying is training in the op. Default: True. Inputs: - **x** (Tensor) - Current words. Tensor of shape (`num_step`, `batch_size`, `input_size`). - The data type must be float16 or float32. + The data type must be float16. - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`). - The data type must be float16 or float32. + The data type must be float16. - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`). The data type must be float16 or float32. - **seq_length** (Tensor) - The length of each batch. Tensor of shape (`batch_size`). Only `None` is currently supported. - **init_h** (Tensor) - Hidden state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`). + The data type must be float16. - **init_c** (Tensor) - Cell state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`). + The data type must be float16. Outputs: - **y** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`). @@ -5664,7 +5666,9 @@ class DynamicRNN(PrimitiveWithInfer): >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) >>> dynamic_rnn = P.DynamicRNN() - >>> output = lstm(x, w, b, None, init_h, init_c) + >>> output = dynamic_rnn(x, w, b, None, init_h, init_c) + >>> output[0].shape + (2, 16, 32) """ @prim_attr_register @@ -5684,7 +5688,7 @@ class DynamicRNN(PrimitiveWithInfer): self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) - self.num_proj = validator.check_value_type("num_proj", num_proj, [int], self.name) + self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name) self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) @@ -5721,11 +5725,11 @@ class DynamicRNN(PrimitiveWithInfer): return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): - validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name) - validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name) + validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) + validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float16,), self.name) validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name) - validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name) - validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name) + validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float16,), self.name) + validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float16,), self.name) return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 946d644b5c..068be69000 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -817,6 +817,17 @@ class BasicLSTMCellNet(nn.Cell): return self.lstm(x, h, c, w, b) +class DynamicGRUV2Net(nn.Cell): + """ DynamicGRUV2Net definition """ + + def __init__(self): + super(DynamicGRUV2Net, self).__init__() + self.dynamic_gru = inner.DynamicGRUV2() + + def construct(self, x, w_i, w_h, b_i, b_h, init_h): + return self.dynamic_gru(x, w_i, w_h, b_i, b_h, None, init_h) + + class EditDistance(nn.Cell): def __init__(self, hypothesis_shape, truth_shape, normalize=True): super(EditDistance, self).__init__() @@ -2508,6 +2519,19 @@ test_case_other_ops = [ Tensor(np.random.rand(1, 64).astype(np.float16)), Tensor(np.random.rand(1, 64).astype(np.float16)), Tensor(np.random.rand(1, 64).astype(np.float16))]}), + ('DynamicGRUV2Net', { + 'block': DynamicGRUV2Net(), + 'desc_inputs': [Tensor(np.random.rand(2, 8, 64).astype(np.float16)), + Tensor(np.random.rand(64, 48).astype(np.float16)), + Tensor(np.random.rand(16, 48).astype(np.float16)), + Tensor(np.random.rand(48).astype(np.float16)), + Tensor(np.random.rand(48).astype(np.float16)), + Tensor(np.random.rand(8, 16).astype(np.float16))], + 'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}), ] test_case_quant_ops = [