|
|
@ -253,11 +253,14 @@ class LSTM(Cell):
|
|
|
|
x = self.transpose(x, (1, 0, 2))
|
|
|
|
x = self.transpose(x, (1, 0, 2))
|
|
|
|
h, c = hx
|
|
|
|
h, c = hx
|
|
|
|
if self.is_ascend:
|
|
|
|
if self.is_ascend:
|
|
|
|
|
|
|
|
x_dtype = F.dtype(x)
|
|
|
|
|
|
|
|
h_dtype = F.dtype(h)
|
|
|
|
|
|
|
|
c_dtype = F.dtype(c)
|
|
|
|
_check_input_3d(F.shape(h), "h of hx", self.cls_name)
|
|
|
|
_check_input_3d(F.shape(h), "h of hx", self.cls_name)
|
|
|
|
_check_input_3d(F.shape(c), "c of hx", self.cls_name)
|
|
|
|
_check_input_3d(F.shape(c), "c of hx", self.cls_name)
|
|
|
|
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
|
|
|
_check_input_dtype(x_dtype, "x", [mstype.float32, mstype.float16], self.cls_name)
|
|
|
|
_check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
|
|
|
|
_check_input_dtype(h_dtype, "h", [mstype.float32, mstype.float16], self.cls_name)
|
|
|
|
_check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
|
|
|
|
_check_input_dtype(c_dtype, "c", [mstype.float32, mstype.float16], self.cls_name)
|
|
|
|
x = self.cast(x, mstype.float16)
|
|
|
|
x = self.cast(x, mstype.float16)
|
|
|
|
h = self.cast(h, mstype.float16)
|
|
|
|
h = self.cast(h, mstype.float16)
|
|
|
|
c = self.cast(c, mstype.float16)
|
|
|
|
c = self.cast(c, mstype.float16)
|
|
|
@ -265,6 +268,9 @@ class LSTM(Cell):
|
|
|
|
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
|
|
|
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
|
|
|
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
|
|
|
|
|
|
|
x = self.cast(x, x_dtype)
|
|
|
|
|
|
|
|
h = self.cast(h, h_dtype)
|
|
|
|
|
|
|
|
c = self.cast(c, c_dtype)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
|
|
|
|
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
|
|
|
|
if self.batch_first:
|
|
|
|
if self.batch_first:
|
|
|
|