Change LSTM Ascend parameter type from fp16 to fp32

pull/12302/head
ttudu 4 years ago
parent 461e020bb9
commit f4193137e5

@ -198,15 +198,15 @@ class StackLSTMAscend(nn.Cell):
# forward weight init
w_np_fw = np.random.uniform(-stdv,
stdv,
(input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float16)
(input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float32)
w_fw = Parameter(initializer(Tensor(w_np_fw), w_np_fw.shape), name="w_fw_layer" + str(i))
weights_fw.append(w_fw)
# forward bias init
if has_bias:
b_fw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float16)
b_fw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float32)
b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i))
else:
b_fw = np.zeros((hidden_size * 4)).astype(np.float16)
b_fw = np.zeros((hidden_size * 4)).astype(np.float32)
b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i))
bias_fw.append(b_fw)
@ -214,21 +214,21 @@ class StackLSTMAscend(nn.Cell):
# backward weight init
w_np_bw = np.random.uniform(-stdv,
stdv,
(input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float16)
(input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float32)
w_bw = Parameter(initializer(Tensor(w_np_bw), w_np_bw.shape), name="w_bw_layer" + str(i))
weights_bw.append(w_bw)
# backward bias init
if has_bias:
b_bw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float16)
b_bw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float32)
b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i))
else:
b_bw = np.zeros((hidden_size * 4)).astype(np.float16)
b_bw = np.zeros((hidden_size * 4)).astype(np.float32)
b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i))
bias_bw.append(b_bw)
# layer init
self.lstm = LSTM_Ascend(bidirectional=bidirectional)
self.lstm = LSTM_Ascend(bidirectional=bidirectional).to_float(mstype.float16)
self.weight_fw = ParameterTuple(tuple(weights_fw))
self.weight_bw = ParameterTuple(tuple(weights_bw))

Loading…
Cancel
Save