diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 8759e549c2..a4ceb8ec94 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -20,6 +20,7 @@
 #include
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
+#include "backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h"
 #include "backend/optimizer/ascend/ir_fission/bn_split.h"
 #include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
 #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
@@ -280,6 +281,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
   ir_fusion_pm->AddPass(std::make_shared());
+  ir_fusion_pm->AddPass(std::make_shared<DynamicGRUV2GradFission>());
   AddAscendIRFusionRulesPass(ir_fusion_pm.get());
   AddAscendIRFusionPass(ir_fusion_pm.get());
   ir_fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc
new file mode 100644
index 0000000000..269a45ff11
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.cc
@@ -0,0 +1,344 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h"
+#include <vector>
+#include <memory>
+#include <string>
+#include <algorithm>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+
+namespace {
+constexpr size_t kDynamicGRUV2GradInputNum = 12;
+constexpr size_t kDynamicGRUV2GradOutputNum = 6;
+constexpr size_t kSplitVOutputNum = 2;
+constexpr size_t kGRUV2HiddenGradOutputNum = 3;
+
+AnfNodePtr CreateGRUV2HiddenGradNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  const auto &dynamic_gru_v2_grad_inputs = cnode->inputs();
+  std::vector<AnfNodePtr> gru_v2_hidden_grad_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kGRUV2HiddenGradOpName)),
+    dynamic_gru_v2_grad_inputs[3],
+    dynamic_gru_v2_grad_inputs[5],
+    dynamic_gru_v2_grad_inputs[6],
+    dynamic_gru_v2_grad_inputs[7],
+    dynamic_gru_v2_grad_inputs[8],
+    dynamic_gru_v2_grad_inputs[9],
+    dynamic_gru_v2_grad_inputs[10],
+    dynamic_gru_v2_grad_inputs[11],
+    dynamic_gru_v2_grad_inputs[12]};
+
+  std::vector<AnfNodePtr> ori_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, node, kDynamicGRUV2GradOutputNum, &ori_outputs);
+  auto gru_v2_hidden_grad_op = graph->NewCNode(gru_v2_hidden_grad_inputs);
+  MS_EXCEPTION_IF_NULL(gru_v2_hidden_grad_op);
+  auto h_dtype = AnfAlgo::GetOutputInferDataType(dynamic_gru_v2_grad_inputs[6], 0);
+  auto types = {h_dtype, h_dtype, h_dtype};
+  std::vector<size_t> dh_preh_shape = AnfAlgo::GetOutputInferShape(ori_outputs[5], 0);
+  std::vector<size_t> dgate_h_shape = {AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0)[0],
+                                       AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0)[1],
+                                       3 * AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0)[2]};
+  std::vector<size_t> dnx_t_shape = AnfAlgo::GetOutputInferShape(dynamic_gru_v2_grad_inputs[6], 0);
+  auto shapes = {dh_preh_shape, dgate_h_shape, dnx_t_shape};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, gru_v2_hidden_grad_op.get());
+  auto gate_order = AnfAlgo::GetNodeAttr<std::string>(cnode, "gate_order");
+  AnfAlgo::SetNodeAttr("gate_order", MakeValue(gate_order), gru_v2_hidden_grad_op);
+  return gru_v2_hidden_grad_op;
+}
+
+AnfNodePtr CreateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  // SplitV
+  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
+  auto split_vd = graph->NewCNode(splitvd_input);
+  MS_EXCEPTION_IF_NULL(split_vd);
+  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
+  size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[0];
+  size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[1];
+  size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[2];
+  std::vector<size_t> shape = {t_size - IntToSize(1), batch, hidden_size};
+  std::vector<size_t> shape2 = {IntToSize(1), batch, hidden_size};
+  std::vector<std::vector<size_t>> shapes = {shape, shape2};
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
+  AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(0)), split_vd);
+  AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(2)), split_vd);
+  std::vector<int64_t> size_splits = {SizeToLong(t_size - 1), SizeToLong(1)};
+  AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
+  return split_vd;
+}
+
+AnfNodePtr CreateHReshape(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto ori_shape = AnfAlgo::GetOutputInferShape(node, 0);
+  std::vector<std::vector<size_t>> shape_tmp;
+  if (ori_shape.size() == 3) {
+    shape_tmp = {ori_shape};
+  } else {
+    shape_tmp = {{IntToSize(1), ori_shape[0], ori_shape[1]}};
+  }
+  auto ori_dtype = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  // reshape
+  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())), node};
+  auto reshape = graph->NewCNode(reshape_input);
+  AnfAlgo::SetOutputInferTypeAndShape(ori_dtype, shape_tmp, reshape.get());
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
+  return reshape;
+}
+
+AnfNodePtr CreateHConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node1);
+  MS_EXCEPTION_IF_NULL(node2);
+  std::vector<AnfNodePtr> ori_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, node2, 2, &ori_outputs);
+  auto reshape = CreateHReshape(graph, node1);
+
+  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
+                                           reshape, ori_outputs[0]};
+  auto concat_op = graph->NewCNode(concat_inputs);
+  MS_EXCEPTION_IF_NULL(concat_op);
+
+  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0] + 1, AnfAlgo::GetOutputInferShape(node2, 0)[1],
+                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
+  auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat_op);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
+  AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(0)), concat_op);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
+  return concat_op;
+}
+
+AnfNodePtr CreateDgateHSplitVDNode(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  // SplitV
+  std::vector<AnfNodePtr> splitvd_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())), node};
+  auto split_vd = graph->NewCNode(splitvd_input);
+  MS_EXCEPTION_IF_NULL(split_vd);
+  auto dtypes = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(node, 0)};
+  size_t t_size = AnfAlgo::GetOutputInferShape(node, 0)[0];
+  size_t batch = AnfAlgo::GetOutputInferShape(node, 0)[1];
+  size_t hidden_size = AnfAlgo::GetOutputInferShape(node, 0)[2] / 3;
+  std::vector<size_t> shape = {t_size, batch, 2 * hidden_size};
+  std::vector<size_t> shape2 = {t_size, batch, hidden_size};
+  std::vector<std::vector<size_t>> shapes = {shape, shape2};
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_vd.get());
+  AnfAlgo::SetNodeAttr("split_dim", MakeValue(SizeToLong(2)), split_vd);
+  AnfAlgo::SetNodeAttr("num_split", MakeValue(SizeToLong(2)), split_vd);
+  std::vector<int64_t> size_splits = {2 * SizeToLong(hidden_size), SizeToLong(hidden_size)};
+  AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split_vd);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_vd);
+  return split_vd;
+}
+
+AnfNodePtr CreateDgateXConcatDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
+  MS_EXCEPTION_IF_NULL(graph);
+  // node1: dgate_h_split
+  // node2: dnt_x
+  MS_EXCEPTION_IF_NULL(node1);
+  MS_EXCEPTION_IF_NULL(node2);
+  std::vector<AnfNodePtr> ori_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, node1, 2, &ori_outputs);
+
+  // ConcatD
+  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
+                                           ori_outputs[0], node2};
+  auto concat_op = graph->NewCNode(concat_inputs);
+  MS_EXCEPTION_IF_NULL(concat_op);
+  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node2, 0)[0], AnfAlgo::GetOutputInferShape(node2, 0)[1],
+                               AnfAlgo::GetOutputInferShape(node1, 0)[2] + AnfAlgo::GetOutputInferShape(node2, 0)[2]};
+  auto types = {AnfAlgo::GetOutputInferDataType(node2, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, concat_op.get());
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(2)), concat_op);
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{2}), concat_op);
+  AnfAlgo::SetNodeAttr("axis", MakeValue(SizeToLong(2)), concat_op);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat_op);
+  return concat_op;
+}
+
+AnfNodePtr CreateWBroadcastToDNode(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
+  MS_EXCEPTION_IF_NULL(graph);
+  // node1 : input node
+  // node2 : origin input x
+  MS_EXCEPTION_IF_NULL(node1);
+  MS_EXCEPTION_IF_NULL(node2);
+  // BroadcastTo
+  std::vector<AnfNodePtr> braodcast_to_input = {NewValueNode(std::make_shared<Primitive>(kBroadcastToOpName)), node1};
+  auto broadcast_to_d = graph->NewCNode(braodcast_to_input);
+  MS_EXCEPTION_IF_NULL(broadcast_to_d);
+  size_t t_size = AnfAlgo::GetOutputInferShape(node2, 0)[0];
+  size_t batch = AnfAlgo::GetOutputInferShape(node1, 0)[0];
+  size_t gate_size = AnfAlgo::GetOutputInferShape(node1, 0)[1];
+  std::vector<size_t> shape = {t_size, batch, gate_size};
+  auto type = {AnfAlgo::GetOutputInferDataType(node1, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(type, {shape}, broadcast_to_d.get());
+
+  std::vector<int64_t> attr_shape = {SizeToLong(t_size), SizeToLong(batch), SizeToLong(gate_size)};
+  AnfAlgo::SetNodeAttr("shape", MakeValue(attr_shape), broadcast_to_d);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), broadcast_to_d);
+  return broadcast_to_d;
+}
+
+AnfNodePtr CreateDhxBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node1);
+  MS_EXCEPTION_IF_NULL(node2);
+  // BatchMatMul
+  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
+                                           node1, node2};
+  auto batch_matmul = graph->NewCNode(matmul_inputs);
+  MS_EXCEPTION_IF_NULL(batch_matmul);
+  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[2],
+                               AnfAlgo::GetOutputInferShape(node2, 0)[2]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
+  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
+  return batch_matmul;
+}
+
+AnfNodePtr CreateDwhBatchMatMul(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node1);
+  MS_EXCEPTION_IF_NULL(node2);
+  // BatchMatMul
+  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
+                                           node1, node2};
+  auto batch_matmul = graph->NewCNode(matmul_inputs);
+  MS_EXCEPTION_IF_NULL(batch_matmul);
+  std::vector<size_t> shape = {AnfAlgo::GetOutputInferShape(node1, 0)[0], AnfAlgo::GetOutputInferShape(node1, 0)[1],
+                               AnfAlgo::GetOutputInferShape(node2, 0)[1]};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {shape}, batch_matmul.get());
+  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
+  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), batch_matmul);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
+  return batch_matmul;
+}
+
+AnfNodePtr CreateDwReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  // ReduceSumD for dw_x and dw_h
+  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
+                                              node};
+  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
+  MS_EXCEPTION_IF_NULL(reduce_sumd);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node2, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, reduce_sumd.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sumd);
+  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
+  return reduce_sumd;
+}
+
+AnfNodePtr CreateDbReduceSumDNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &node2) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node2);
+  // ReduceSumD for db_x and db_h
+  std::vector<AnfNodePtr> reducesum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
+                                              node};
+  auto reduce_sumd = graph->NewCNode(reducesum_inputs);
+  MS_EXCEPTION_IF_NULL(reduce_sumd);
+
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  std::vector<size_t> shape = {3 * AnfAlgo::GetOutputInferShape(node2, 0)[1]};
+  AnfAlgo::SetOutputInferTypeAndShape(types, {shape}, reduce_sumd.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0, 1}), reduce_sumd);
+  AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_sumd);
+  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sumd);
+  return reduce_sumd;
+}
+}  // namespace
+
+const BaseRef DynamicGRUV2GradFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimDynamicGRUV2Grad, Xs});
+}
+
+const AnfNodePtr DynamicGRUV2GradFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto dynamic_gru_v2_grad_cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(dynamic_gru_v2_grad_cnode);
+  if (dynamic_gru_v2_grad_cnode->size() < kDynamicGRUV2GradInputNum + 1) {
+    MS_LOG(INFO) << "The node " << dynamic_gru_v2_grad_cnode->DebugString() << " has less than "
+                 << kDynamicGRUV2GradInputNum << " inputs";
+    return nullptr;
+  }
+
+  // input_list of dynamic_gru_v2_grad
+  const auto &ori_inputs = dynamic_gru_v2_grad_cnode->inputs();
+  // add gru_v2_gru_hidden
+  auto gru_v2_gru_hidden = CreateGRUV2HiddenGradNode(func_graph, dynamic_gru_v2_grad_cnode);
+  std::vector<AnfNodePtr> gru_hidden_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, gru_v2_gru_hidden, kGRUV2HiddenGradOutputNum, &gru_hidden_outputs);
+  size_t step_num = AnfAlgo::GetOutputInferShape(ori_inputs[1], 0)[0];
+  AnfNodePtr dwh_batch_matmul = nullptr;
+  if (step_num != 1) {
+    // split h
+    auto h_split = CreateHSplitVDNode(func_graph, ori_inputs[6]);
+    // concat(h, h_split)
+    auto h_concat = CreateHConcatDNode(func_graph, ori_inputs[5], h_split);
+    // batchmatmul(h_concat.T, dgate_h)
+    dwh_batch_matmul = CreateDhxBatchMatMul(func_graph, h_concat, gru_hidden_outputs[1]);
+  } else {
+    auto reshape = CreateHReshape(func_graph, ori_inputs[5]);
+    // batchmatmul(init_h.T, dgate_h)
+    dwh_batch_matmul = CreateDhxBatchMatMul(func_graph, reshape, gru_hidden_outputs[1]);
+  }
+  // split dgate_h
+  auto dgate_h_split = CreateDgateHSplitVDNode(func_graph, gru_hidden_outputs[1]);
+  // concat(dgate_h_split[0], dnt_x) to dgate_x
+  auto dgate_x_concat = CreateDgateXConcatDNode(func_graph, dgate_h_split, gru_hidden_outputs[2]);
+  // broadcast weight_input [input_size, 3 * hidden_size] to [t_size, input_size, 3 * hidden_size]
+  auto w_input_broadcast = CreateWBroadcastToDNode(func_graph, ori_inputs[2], ori_inputs[1]);
+  // batchmatmul(x.T, dgate_x_concat)
+  auto dwx_batch_matmul = CreateDhxBatchMatMul(func_graph, ori_inputs[1], dgate_x_concat);
+  // batchmatmul(dgate_x_concat, w_input_broadcast.T)
+  auto dxt_batch_matmul = CreateDwhBatchMatMul(func_graph, dgate_x_concat, w_input_broadcast);
+  // reducesum dw_x and dw_h
+  auto dwx_reduce_sum = CreateDwReduceSumDNode(func_graph, dwx_batch_matmul, ori_inputs[2]);
+  auto dwh_reduce_sum = CreateDwReduceSumDNode(func_graph, dwh_batch_matmul, ori_inputs[3]);
+  // reducesum db_x and db_h
+  auto dbx_reduce_sum = CreateDbReduceSumDNode(func_graph, dgate_x_concat, ori_inputs[5]);
+  auto dbh_reduce_sum = CreateDbReduceSumDNode(func_graph, gru_hidden_outputs[1], ori_inputs[5]);
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple),
+                                               dwx_reduce_sum,
+                                               dwh_reduce_sum,
+                                               dbx_reduce_sum,
+                                               dbh_reduce_sum,
+                                               dxt_batch_matmul,
+                                               gru_hidden_outputs[0]};
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  return make_tuple;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h
new file mode 100644
index 0000000000..0fef961730
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class DynamicGRUV2GradFission : public PatternProcessPass {
+ public:
+  explicit DynamicGRUV2GradFission(bool multigraph = true)
+      : PatternProcessPass("dynamic_gru_grad_v2_fission", multigraph) {}
+  ~DynamicGRUV2GradFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 05c18afab9..1330249993 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1157,7 +1157,7 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
         reset_after (bool): An bool identifying whether to apply reset gate after matrix multiplication.
             Default: True.
     Inputs:
-        - **x** (Tensor) - Current words. Tensor of shape :math:`({num_step, batch_size, input_size)`.
+        - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
           The data type must be float16 or float32.
         - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`.
           The data type must be float16 or float32.
@@ -1168,17 +1168,17 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
           if num_proj == 0 `(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
         - **init_h** (Tensor) - Hidden state of initial time.
-          Tensor of shape :math:`(batch_size, hidden_size)`, or None.
+          Tensor of shape :math:`(batch_size, hidden_size)`.
           The data type must be float16 or float32.
-        - **h** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
         - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
-        - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `h`.
-        - **update** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
+        - **update** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
-        - **reset** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **reset** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          The data type must be float16 or float32.
-        - **new** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
+        - **new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
         - **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
           The data type must be float16 or float32.
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 952e36a837..cc87c4dd1a 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -492,7 +492,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
           Only `None` is currently supported.
         - **init_h** (Tensor) - Hidden state of initial time.
-          Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`, or None.
+          Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`.
          The data type must be float16 or float32.

     Outputs:
@@ -511,10 +511,9 @@ class DynamicGRUV2(PrimitiveWithInfer):
         - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
           Has the same data type with input `bais_type`.

-        - If `bias_input`, `bias_hidden` and `init_h` all are `None`, `bias_type` is float32.
+        - If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32.
         - If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`.
         - If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`.
-        - Otherwise, `bias_type` is the date type of `init_h`.

     Examples:
         >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
@@ -553,8 +552,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
         self.add_prim_attr("io_format", "ND")

-    def infer_shape(self, x_shape, winput_shape, whidden_shape,
-                    binput_shape=None, bhidden_shape=None, seq_shape=None, h_shape=None):
+    def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
         validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
         validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
         validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
@@ -564,7 +562,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         if winput_shape[-1] % 3 != 0:
             raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.")

-        self.placeholder_index = [3, 4, 5, 6]
+        self.placeholder_index = [3, 4, 5]
         if binput_shape is not None:
             validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
             validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
@@ -574,14 +572,12 @@ class DynamicGRUV2(PrimitiveWithInfer):
             validator.check("bias_hidden_shape", bhidden_shape, "3 * hidden_shape", [3 * hidden_size],
                             Rel.EQ, self.name)
             self.placeholder_index.remove(4)
-        if h_shape is not None:
-            validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
-            validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
-            validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
-            self.placeholder_index.remove(6)
         if seq_shape is not None:
             raise ValueError(f"For {self.name}, seq_shape should be None.")

+        validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
+        validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
+        validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
         validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
                         whidden_shape[-1], Rel.EQ, self.name)
validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) @@ -590,15 +586,15 @@ class DynamicGRUV2(PrimitiveWithInfer): y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) else: y_shape = (num_step, batch_size, hidden_size) - outh_shape = (num_step, batch_size, hidden_size) + out_shape = (num_step, batch_size, hidden_size) self.add_prim_attr("placeholder_index", self.placeholder_index) - return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape + return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape - def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, - binput_dtype=None, bhidden_dtype=None, seq_dtype=None, h_dtype=None): + def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) + validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name) b_dtype = mstype.float32 if binput_dtype is not None: validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, @@ -608,10 +604,7 @@ class DynamicGRUV2(PrimitiveWithInfer): validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, (mstype.float16, mstype.float32), self.name) b_dtype = bhidden_dtype - elif h_dtype is not None: - validator.check_tensor_dtype_valid("init_h dtype", h_dtype, - (mstype.float16, mstype.float32), self.name) - b_dtype = h_dtype + return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 00b6aabf1d..b4cdd45e12 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -2532,7 +2532,11 @@ test_case_other_ops = [ Tensor(np.random.rand(48).astype(np.float16)), Tensor(np.random.rand(48).astype(np.float16)), Tensor(np.random.rand(8, 16).astype(np.float16))], - 'skip': ['backward']}), + 'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16)), + Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}), ] test_case_quant_ops = [