From 136bded1fa2ddfbc210280a82a00b996808a00cf Mon Sep 17 00:00:00 2001
From: liuxiao93
Date: Fri, 18 Dec 2020 16:25:44 +0800
Subject: [PATCH] Fix output for Ascend backend of nn.LSTM when dropout is 1.0.

---
 mindspore/nn/layer/lstm.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py
index d66aa11dea..c738b95a25 100755
--- a/mindspore/nn/layer/lstm.py
+++ b/mindspore/nn/layer/lstm.py
@@ -154,8 +154,12 @@ class LSTM(Cell):
         self.concat_2dim = P.Concat(axis=2)
         self.cast = P.Cast()
         self.shape = P.Shape()
-        if dropout != 0:
-            self.dropout_op = nn.Dropout(float(dropout))
+        if dropout < 0 or dropout > 1:
+            raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
+        if dropout == 1:
+            self.dropout_op = P.ZerosLike()
+        else:
+            self.dropout_op = nn.Dropout(float(1 - dropout))
         b0 = np.zeros(gate_size, dtype=np.float16)
         self.w_list = []
         self.b_list = []
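For context, here is a minimal, library-free Python sketch of the operator selection this patch introduces. It assumes the reading that `dropout` is the probability of zeroing an element, that `nn.Dropout` in this MindSpore version takes a keep probability (hence `float(1 - dropout)`), and that `dropout == 1` therefore degenerates to "zero everything", which the patch expresses with `P.ZerosLike()`. The helper `make_dropout_op` and the numpy stand-ins are illustrative only, not MindSpore APIs.

```python
# Sketch of the patched branch in LSTM.__init__, using numpy stand-ins.
import numpy as np

def make_dropout_op(dropout: float):
    """Return a callable mimicking the operator chosen by the patched code."""
    if dropout < 0 or dropout > 1:
        raise ValueError(
            "For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
    if dropout == 1:
        # Analogue of P.ZerosLike(): every activation is dropped.
        return lambda x: np.zeros_like(x)
    keep_prob = 1.0 - dropout
    # Analogue of nn.Dropout(keep_prob): keep each element with probability
    # keep_prob and rescale the survivors by 1 / keep_prob.
    def dropout_op(x):
        if keep_prob == 1.0:  # dropout == 0: pass-through
            return x
        mask = np.random.binomial(1, keep_prob, size=x.shape)
        return x * mask / keep_prob
    return dropout_op

# With the new branch, dropout = 1.0 yields an all-zero output rather than
# feeding an out-of-range keep probability to the dropout op.
x = np.ones((2, 3), dtype=np.float32)
print(make_dropout_op(1.0)(x))  # all zeros
print(make_dropout_op(0.0)(x))  # unchanged
```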