@@ -492,7 +492,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
           Only `None` is currently supported.
         - **init_h** (Tensor) - Hidden state of initial time.
-          Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`, or None.
+          Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`.
           The data type must be float16 or float32.

     Outputs:
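With this hunk `init_h` stops being optional. A minimal sketch of a valid `init_h` under the new contract, with illustrative sizes:

```python
import numpy as np
from mindspore import Tensor

# init_h can no longer be None: shape (batch_size, hidden_size), float16 or float32.
batch_size, hidden_size = 8, 16
init_h = Tensor(np.zeros((batch_size, hidden_size), np.float16))
```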
@@ -511,10 +511,9 @@ class DynamicGRUV2(PrimitiveWithInfer):
         - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
           Has the same data type as input `bias_type`.

-        - If `bias_input`, `bias_hidden` and `init_h` all are `None`, `bias_type` is float32.
+        - If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32.
         - If `bias_input` is not `None`, `bias_type` is the data type of `bias_input`.
         - If `bias_input` is `None` and `bias_hidden` is not `None`, `bias_type` is the data type of `bias_hidden`.
-        - Otherwise, `bias_type` is the data type of `init_h`.

     Examples:
         >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
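The `bias_type` rules above reduce to a two-step fallback chain, which is what `infer_dtype` below implements. A standalone sketch of that selection; the helper name `pick_bias_type` is hypothetical, not part of the operator's API:

```python
# Hypothetical helper illustrating the bias_type selection rules above;
# mirrors the fallback chain in infer_dtype below.
def pick_bias_type(binput_dtype=None, bhidden_dtype=None):
    if binput_dtype is not None:
        return binput_dtype   # bias_input takes precedence when present
    if bhidden_dtype is not None:
        return bhidden_dtype  # otherwise fall back to bias_hidden
    return "float32"          # both None: default to float32
```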
@@ -553,8 +552,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
         self.add_prim_attr("io_format", "ND")

-    def infer_shape(self, x_shape, winput_shape, whidden_shape,
-                    binput_shape=None, bhidden_shape=None, seq_shape=None, h_shape=None):
+    def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
         validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
         validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
         validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
@@ -564,7 +562,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
         if winput_shape[-1] % 3 != 0:
             raise ValueError(f"For {self.name}, weight_input_shape[-1] should be a multiple of 3.")

-        self.placeholder_index = [3, 4, 5, 6]
+        self.placeholder_index = [3, 4, 5]
         if binput_shape is not None:
             validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
             validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
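Reading the indices off the seven-input signature (x=0, weight_input=1, weight_hidden=2, bias_input=3, bias_hidden=4, seq_length=5, init_h=6), the placeholder list now starts with only the optional inputs; index 6 drops out because `init_h` is mandatory after this change. A sketch of how the list shrinks as optional inputs are supplied, assuming the `remove()` calls in `infer_shape` follow that index mapping:

```python
# Illustrative only: how placeholder_index shrinks in infer_shape.
# Positions follow the input order; init_h (6) is no longer a placeholder.
placeholder_index = [3, 4, 5]
binput_given, bhidden_given = True, True   # hypothetical call with both biases
if binput_given:
    placeholder_index.remove(3)
if bhidden_given:
    placeholder_index.remove(4)
print(placeholder_index)  # [5] -- seq_length must remain None
```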
@@ -574,14 +572,12 @@ class DynamicGRUV2(PrimitiveWithInfer):
             validator.check("bias_hidden_shape", bhidden_shape,
                             "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
             self.placeholder_index.remove(4)
-        if h_shape is not None:
-            validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
-            validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
-            validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
-            self.placeholder_index.remove(6)
         if seq_shape is not None:
             raise ValueError(f"For {self.name}, seq_shape should be None.")

+        validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
+        validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
+        validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
         validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
                         whidden_shape[-1], Rel.EQ, self.name)
         validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
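Taken together, the checks above pin down the operator's shape contract. A sketch with illustrative sizes; the variable names are mine, and `hidden_size` is assumed to derive from `weight_hidden`, as the surrounding code suggests:

```python
import numpy as np

# Illustrative shapes satisfying the checks above.
num_step, batch_size, input_size, hidden_size = 2, 8, 64, 16

x        = np.empty((num_step, batch_size, input_size), np.float16)  # rank 3
weight_i = np.empty((input_size, 3 * hidden_size), np.float16)       # rank 2
weight_h = np.empty((hidden_size, 3 * hidden_size), np.float16)      # rank 2
bias_i   = np.empty((3 * hidden_size,), np.float16)                  # rank 1
bias_h   = np.empty((3 * hidden_size,), np.float16)                  # rank 1
init_h   = np.empty((batch_size, hidden_size), np.float16)           # now required

assert weight_i.shape[-1] % 3 == 0                # gate dim is a multiple of 3
assert weight_i.shape[-1] == weight_h.shape[-1]   # weights share the gate dim
assert weight_i.shape[0] == input_size
assert init_h.shape == (batch_size, hidden_size)
```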
@@ -590,15 +586,15 @@ class DynamicGRUV2(PrimitiveWithInfer):
             y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
         else:
             y_shape = (num_step, batch_size, hidden_size)
-        outh_shape = (num_step, batch_size, hidden_size)
+        out_shape = (num_step, batch_size, hidden_size)
         self.add_prim_attr("placeholder_index", self.placeholder_index)
-        return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape
+        return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape

-    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype,
-                    binput_dtype=None, bhidden_dtype=None, seq_dtype=None, h_dtype=None):
+    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
         validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
         validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
         validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
+        validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
         b_dtype = mstype.float32
         if binput_dtype is not None:
             validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
@@ -608,10 +604,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
             validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
                                                (mstype.float16, mstype.float32), self.name)
             b_dtype = bhidden_dtype
-        elif h_dtype is not None:
-            validator.check_tensor_dtype_valid("init_h dtype", h_dtype,
-                                               (mstype.float16, mstype.float32), self.name)
-            b_dtype = h_dtype

         return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
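For reference, a usage sketch consistent with the truncated Examples block above. The shapes follow the docstring, but this is a reconstruction rather than the verbatim example, and it assumes the operator is exposed as `mindspore.ops.DynamicGRUV2` on an Ascend backend:

```python
import numpy as np
from mindspore import Tensor, ops

x        = Tensor(np.random.rand(2, 8, 64).astype(np.float16))  # (num_step, batch_size, input_size)
weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))    # (input_size, 3 * hidden_size)
weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))    # (hidden_size, 3 * hidden_size)
bias_i   = Tensor(np.random.rand(48).astype(np.float16))
bias_h   = Tensor(np.random.rand(48).astype(np.float16))
init_h   = Tensor(np.random.rand(8, 16).astype(np.float16))     # required after this change

gru_v2 = ops.DynamicGRUV2()
# seq_length must be None; the six outputs share bias_type (here float16).
y, out_h, update, reset, new, hidden_new = gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h)
```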