|
|
@ -931,19 +931,19 @@ class LSTMGradData(PrimitiveWithInfer):
|
|
|
|
def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
|
|
|
|
def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
|
|
|
|
hx_shape, cx_shape, reserve_shape, state_shape):
|
|
|
|
hx_shape, cx_shape, reserve_shape, state_shape):
|
|
|
|
# dhy and dcy should be same shape
|
|
|
|
# dhy and dcy should be same shape
|
|
|
|
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
|
|
|
|
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
|
|
|
|
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
|
|
|
|
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
|
|
|
|
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
|
|
|
|
validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
|
|
|
|
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
# dy: (seq_len, batch_size, hidden_size * num_directions)
|
|
|
|
# dy: (seq_len, batch_size, hidden_size * num_directions)
|
|
|
|
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
|
|
|
|
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
|
|
|
|
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
|
|
|
|
validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
# (seq_len, batch_size, input_size)
|
|
|
|
# (seq_len, batch_size, input_size)
|
|
|
|
dx_shape = (y_shape[0], y_shape[1], self.input_size)
|
|
|
|
dx_shape = (y_shape[0], y_shape[1], self.input_size)
|
|
|
@ -1015,19 +1015,19 @@ class LSTMGrad(PrimitiveWithInfer):
|
|
|
|
def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
|
|
|
|
def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
|
|
|
|
dcy_shape, reserve_shape):
|
|
|
|
dcy_shape, reserve_shape):
|
|
|
|
# dhy and dcy should be same shape
|
|
|
|
# dhy and dcy should be same shape
|
|
|
|
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
|
|
|
|
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
|
|
|
|
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
|
|
|
|
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
|
|
|
|
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
|
|
|
|
validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
|
|
|
|
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
# dy: (seq_len, batch_size, hidden_size * num_directions)
|
|
|
|
# dy: (seq_len, batch_size, hidden_size * num_directions)
|
|
|
|
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
|
|
|
|
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
|
|
|
|
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
|
|
|
|
validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
# (seq_len, batch_size, input_size)
|
|
|
|
# (seq_len, batch_size, input_size)
|
|
|
|
dx_shape = (y_shape[0], y_shape[1], self.input_size)
|
|
|
|
dx_shape = (y_shape[0], y_shape[1], self.input_size)
|
|
|
@ -1069,7 +1069,7 @@ class DynamicRNNGrad(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
|
|
|
|
c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
|
|
|
|
c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
|
|
|
|
validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
|
|
|
|
num_step, batch_size, input_size = x_shape
|
|
|
|
num_step, batch_size, input_size = x_shape
|
|
|
|
hidden_size = w_shape[-1] // 4
|
|
|
|
hidden_size = w_shape[-1] // 4
|
|
|
|
if w_shape[-1] % 4 != 0:
|
|
|
|
if w_shape[-1] % 4 != 0:
|
|
|
@ -1575,7 +1575,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
|
|
|
|
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
|
|
|
|
# dhy and dcy should be same shape
|
|
|
|
# dhy and dcy should be same shape
|
|
|
|
validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
|
|
|
|
validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
|
|
|
@ -1624,7 +1624,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
|
|
|
|
self.add_prim_attr("io_format", "HWCN")
|
|
|
|
self.add_prim_attr("io_format", "HWCN")
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, h_shape, dgate_shape):
|
|
|
|
def infer_shape(self, x_shape, h_shape, dgate_shape):
|
|
|
|
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
|
|
|
|
validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
|
|
|
|
validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
|
|
|
|
validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
|
|
|
@ -1656,8 +1656,8 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
|
|
|
|
self.add_prim_attr("io_format", "ND")
|
|
|
|
self.add_prim_attr("io_format", "ND")
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, dgate_shape, w_shape):
|
|
|
|
def infer_shape(self, dgate_shape, w_shape):
|
|
|
|
validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
|
|
|
|
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
|
|
|
|
validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
|
|
|
|
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
|
|
|
|
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
|
|
|
|
batch_size = dgate_shape[0]
|
|
|
|
batch_size = dgate_shape[0]
|
|
|
|
hidden_size = dgate_shape[1] // 4
|
|
|
|
hidden_size = dgate_shape[1] // 4
|
|
|
|