From ba3bde6f8563a0b16401745f4eb5d79a2e83cd86 Mon Sep 17 00:00:00 2001
From: liuxiao93
Date: Fri, 5 Mar 2021 18:58:32 +0800
Subject: [PATCH] fix lstm output.

---
 .../optimizer/ascend/ir_fission/pack_fission.cc |  2 +-
 mindspore/nn/layer/lstm.py                      | 12 +++++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc
index 17c2eadd97..da03e91c62 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc
@@ -60,7 +60,7 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
     }
   }
   new_shape.erase(new_shape.begin() + axis + 1);
-  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {output_shape},
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
                                       new_pack.get());
   return new_pack;
 }
diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py
index 5306be54e0..bc4cf121a6 100755
--- a/mindspore/nn/layer/lstm.py
+++ b/mindspore/nn/layer/lstm.py
@@ -253,11 +253,14 @@ class LSTM(Cell):
             x = self.transpose(x, (1, 0, 2))
         h, c = hx
         if self.is_ascend:
+            x_dtype = F.dtype(x)
+            h_dtype = F.dtype(h)
+            c_dtype = F.dtype(c)
             _check_input_3d(F.shape(h), "h of hx", self.cls_name)
             _check_input_3d(F.shape(c), "c of hx", self.cls_name)
-            _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
-            _check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
-            _check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
+            _check_input_dtype(x_dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
+            _check_input_dtype(h_dtype, "h", [mstype.float32, mstype.float16], self.cls_name)
+            _check_input_dtype(c_dtype, "c", [mstype.float32, mstype.float16], self.cls_name)
             x = self.cast(x, mstype.float16)
             h = self.cast(h, mstype.float16)
             c = self.cast(c, mstype.float16)
@@ -265,6 +268,9 @@
                 x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
             else:
                 x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
+            x = self.cast(x, x_dtype)
+            h = self.cast(h, h_dtype)
+            c = self.cast(c, c_dtype)
         else:
             x, h, c, _, _ = self.lstm(x, h, c, self.weight)
         if self.batch_first:
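
Note (not part of the patch): the lstm.py hunks record the inputs' dtypes before the
internal cast to float16 on Ascend, then cast x/h/c back after the stacked RNN runs, so
the internal float16 working dtype no longer leaks to the caller. The pack_fission.cc
hunk, as far as the diff shows, fixes the inferred shape of each newly split Pack node
to use the recomputed new_shape rather than the original node's output_shape. Below is
a minimal usage sketch of the Python-side behavior, assuming the public
mindspore.nn.LSTM API of this era; shapes and parameter values are illustrative only.

    import numpy as np
    from mindspore import Tensor, nn

    # batch_first=True: x is (batch, seq_len, input_size);
    # h0/c0 are (num_layers * num_directions, batch, hidden_size).
    net = nn.LSTM(input_size=10, hidden_size=16, num_layers=1, batch_first=True)
    x = Tensor(np.ones((3, 5, 10), np.float32))
    h0 = Tensor(np.zeros((1, 3, 16), np.float32))
    c0 = Tensor(np.zeros((1, 3, 16), np.float32))

    output, (hn, cn) = net(x, (h0, c0))
    # With this patch, output/hn/cn on Ascend come back in the caller's
    # float32 instead of the internal float16 working dtype.
    print(output.dtype, hn.dtype, cn.dtype)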