fix lstm output.

pull/12933/head
liuxiao93 4 years ago
parent fdd4339f50
commit ba3bde6f85

@ -60,7 +60,7 @@ AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_
}
}
new_shape.erase(new_shape.begin() + axis + 1);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {output_shape},
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
new_pack.get());
return new_pack;
}

@ -253,11 +253,14 @@ class LSTM(Cell):
x = self.transpose(x, (1, 0, 2))
h, c = hx
if self.is_ascend:
x_dtype = F.dtype(x)
h_dtype = F.dtype(h)
c_dtype = F.dtype(c)
_check_input_3d(F.shape(h), "h of hx", self.cls_name)
_check_input_3d(F.shape(c), "c of hx", self.cls_name)
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(x_dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(h_dtype, "h", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(c_dtype, "c", [mstype.float32, mstype.float16], self.cls_name)
x = self.cast(x, mstype.float16)
h = self.cast(h, mstype.float16)
c = self.cast(c, mstype.float16)
@ -265,6 +268,9 @@ class LSTM(Cell):
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
else:
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
x = self.cast(x, x_dtype)
h = self.cast(h, h_dtype)
c = self.cast(c, c_dtype)
else:
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
if self.batch_first:

Loading…
Cancel
Save