!8656 Adapt DynamicGRUV2Grad for Ascend new backend.

From: @liu_xiao_93
Reviewed-by: @jjfeing,@liangchenghui
Signed-off-by: @liangchenghui
pull/8656/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f40a4781e4

@ -20,6 +20,7 @@
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
#include "backend/optimizer/ascend/ir_fission/dynamic_gru_v2_grad_fission.h"
#include "backend/optimizer/ascend/ir_fission/bn_split.h"
#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
@ -280,6 +281,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
ir_fusion_pm->AddPass(std::make_shared<DynamicGRUV2GradFission>());
AddAscendIRFusionRulesPass(ir_fusion_pm.get());
AddAscendIRFusionPass(ir_fusion_pm.get());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());

@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class DynamicGRUV2GradFission : public PatternProcessPass {
public:
explicit DynamicGRUV2GradFission(bool multigraph = true)
: PatternProcessPass("dynamic_gru_grad_v2_fission", multigraph) {}
~DynamicGRUV2GradFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DYNAMIC_GRU_V2_GRAD_FISSION_H_

@ -1157,7 +1157,7 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
reset_after (bool): An bool identifying whether to apply reset gate after matrix multiplication. Default: True.
Inputs:
- **x** (Tensor) - Current words. Tensor of shape :math:`({num_step, batch_size, input_size)`.
- **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
The data type must be float16 or float32.
- **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`.
The data type must be float16 or float32.
@ -1168,17 +1168,17 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
if num_proj == 0 `(num_step, batch_size, hidden_size)`.
The data type must be float16 or float32.
- **init_h** (Tensor) - Hidden state of initial time.
Tensor of shape :math:`(batch_size, hidden_size)`, or None.
Tensor of shape :math:`(batch_size, hidden_size)`.
The data type must be float16 or float32.
- **h** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
- **h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
The data type must be float16 or float32.
- **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
- **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `h`.
- **update** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
- **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
- **update** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
The data type must be float16 or float32.
- **reset** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
- **reset** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
The data type must be float16 or float32.
- **new** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
- **new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
The data type must be float16 or float32.
- **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
The data type must be float16 or float32.

@ -492,7 +492,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
Only `None` is currently supported.
- **init_h** (Tensor) - Hidden state of initial time.
Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`, or None.
Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`.
The data type must be float16 or float32.
Outputs:
@ -511,10 +511,9 @@ class DynamicGRUV2(PrimitiveWithInfer):
- **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type with input `bais_type`.
- If `bias_input`, `bias_hidden` and `init_h` all are `None`, `bias_type` is float32.
- If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32.
- If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`.
- If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`.
- Otherwise, `bias_type` is the date type of `init_h`.
Examples:
>>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
@ -553,8 +552,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, winput_shape, whidden_shape,
binput_shape=None, bhidden_shape=None, seq_shape=None, h_shape=None):
def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
@ -564,7 +562,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
if winput_shape[-1] % 3 != 0:
raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.")
self.placeholder_index = [3, 4, 5, 6]
self.placeholder_index = [3, 4, 5]
if binput_shape is not None:
validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
@ -574,14 +572,12 @@ class DynamicGRUV2(PrimitiveWithInfer):
validator.check("bias_hidden_shape", bhidden_shape,
"3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
self.placeholder_index.remove(4)
if h_shape is not None:
validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
self.placeholder_index.remove(6)
if seq_shape is not None:
raise ValueError(f"For {self.name}, seq_shape should be None.")
validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
whidden_shape[-1], Rel.EQ, self.name)
validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
@ -590,15 +586,15 @@ class DynamicGRUV2(PrimitiveWithInfer):
y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
else:
y_shape = (num_step, batch_size, hidden_size)
outh_shape = (num_step, batch_size, hidden_size)
out_shape = (num_step, batch_size, hidden_size)
self.add_prim_attr("placeholder_index", self.placeholder_index)
return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape
return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype,
binput_dtype=None, bhidden_dtype=None, seq_dtype=None, h_dtype=None):
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
b_dtype = mstype.float32
if binput_dtype is not None:
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
@ -608,10 +604,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = bhidden_dtype
elif h_dtype is not None:
validator.check_tensor_dtype_valid("init_h dtype", h_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = h_dtype
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype

@ -2532,7 +2532,11 @@ test_case_other_ops = [
Tensor(np.random.rand(48).astype(np.float16)),
Tensor(np.random.rand(48).astype(np.float16)),
Tensor(np.random.rand(8, 16).astype(np.float16))],
'skip': ['backward']}),
'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}),
]
test_case_quant_ops = [

Loading…
Cancel
Save