diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index ec45fedf93..a2285b2714 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -93,7 +93,8 @@ class LSTM(Cell): bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False. Inputs: - - **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`). + - **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or + (batch_size, seq_len, `input_size`). - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32 or mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`). Data type of `hx` must be the same as `input`.