@@ -5335,35 +5335,41 @@ class BasicLSTMCell(PrimitiveWithInfer):
         forget_bias (float): Add forget bias to the forget gate biases in order to reduce the scale of forgetting. Default to 1.0.
         state_is_tuple (bool): If true, the state is a tuple of two tensors, containing h and c; if false, the state
             is one tensor and it needs to be split first. Default to True.
-        activation (str): Activation. Default to "tanh".
+        activation (str): Activation. Default to "tanh". Only "tanh" is currently supported.

     Inputs:
         - **x** (Tensor) - Current words. Tensor of shape (`batch_size`, `input_size`).
+          The data type must be float16 or float32.
         - **h** (Tensor) - Hidden state of the last moment. Tensor of shape (`batch_size`, `hidden_size`).
+          The data type must be float16 or float32.
         - **c** (Tensor) - Cell state of the last moment. Tensor of shape (`batch_size`, `hidden_size`).
-        - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
-        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
+          The data type must be float16 or float32.
+        - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
+          The data type must be float16 or float32.
+        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
+          The data type must be the same as `c`.

     Outputs:
         - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
-        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type as input `c`.
+        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`). With data type of float16.
         - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type as input `c`.
         - **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type as input `c`.
         - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type as input `c`.
         - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type as input `c`.
         - **tanhct** (Tensor) - Forward :math:`\tanh(c_t)` cache at moment `t`.
-          Tensor of shape (`batch_size`, `hidden_size`).
+          Tensor of shape (`batch_size`, `hidden_size`). Has the same type as input `c`.

     Examples:
-        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
-        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]],
-        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
-        >>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
-        >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
+        >>> x = Tensor(np.random.rand(1, 32).astype(np.float16))
+        >>> h = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> c = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> w = Tensor(np.random.rand(96, 256).astype(np.float16))
+        >>> b = Tensor(np.random.rand(256,).astype(np.float16))
         >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
         >>> lstm(x, h, c, w, b)
     """
@@ -5375,42 +5381,38 @@ class BasicLSTMCell(PrimitiveWithInfer):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
         self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
-        # (batch_size, input_size)
-        validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
-        # h and c should be same shape
-        validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
-        validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
-        validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
-        validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
-        validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
-        validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
+        # x: (batch_size, input_size); h and c: (batch_size, hidden_size)
+        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
+        validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
+        # w: (input_size + hidden_size, 4 * hidden_size); b: (4 * hidden_size,)
+        validator.check("w_shape[1]", w_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[0]", w_shape[0], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
+        validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
         ct_shape = c_shape
-        ht_shape = h_shape
-        it_shape = h_shape
-        jt_shape = h_shape
-        ft_shape = h_shape
-        ot_shape = h_shape
-        tanhct_shape = h_shape
+        ht_shape = c_shape
+        it_shape = c_shape
+        jt_shape = c_shape
+        ft_shape = c_shape
+        ot_shape = c_shape
+        tanhct_shape = c_shape
         return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)

     def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
-        validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
-        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
-        return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
+        validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name)
+        # b must share c's dtype; every output follows c except ht, which is always float16.
+        args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
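
Read together, the new `infer_shape` checks replace the 4-D fractal-format constraints with plain ND ones, and `infer_dtype` now ties `b` to `c` while fixing `ht` to float16. A quick restatement of the shape rules as bare asserts, using the docstring's example sizes (the variable names are illustrative, not the validator API):

    batch, input_size, hidden_size = 1, 32, 64
    x_shape, h_shape = (batch, input_size), (batch, hidden_size)
    c_shape = (batch, hidden_size)
    w_shape, b_shape = (input_size + hidden_size, 4 * hidden_size), (4 * hidden_size,)

    assert len(w_shape) == 2 and len(b_shape) == 1               # ND layout; was rank 4 and 4
    assert x_shape[0] == h_shape[0] == c_shape[0]                # batch sizes agree
    assert c_shape[1] == h_shape[1]                              # hidden sizes agree
    assert w_shape == (x_shape[1] + h_shape[1], 4 * h_shape[1])  # (96, 256)
    assert b_shape[0] == 4 * h_shape[1]                          # 256
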
class InTopK(PrimitiveWithInfer):