|
|
@ -6195,7 +6195,7 @@ class CTCLoss(PrimitiveWithInfer):
|
|
|
|
return inputs, inputs
|
|
|
|
return inputs, inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CTCGreedyDecoder(PrimitiveWithInfer):
|
|
|
|
class CTCGreedyDecoder(PrimitiveWithCheck):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Performs greedy decoding on the logits given in inputs.
|
|
|
|
Performs greedy decoding on the logits given in inputs.
|
|
|
|
|
|
|
|
|
|
|
@ -6221,29 +6221,22 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
|
|
|
|
containing sequence log-probability, has the same type as `inputs`.
|
|
|
|
containing sequence log-probability, has the same type as `inputs`.
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
Examples:
|
|
|
|
>>> class CTCGreedyDecoderNet(nn.Cell):
|
|
|
|
|
|
|
|
... def __init__(self):
|
|
|
|
|
|
|
|
... super(CTCGreedyDecoderNet, self).__init__()
|
|
|
|
|
|
|
|
... self.ctc_greedy_decoder = P.CTCGreedyDecoder()
|
|
|
|
|
|
|
|
... self.assert_op = ops.Assert(300)
|
|
|
|
|
|
|
|
...
|
|
|
|
|
|
|
|
... def construct(self, inputs, sequence_length):
|
|
|
|
|
|
|
|
... out = self.ctc_greedy_decoder(inputs,sequence_length)
|
|
|
|
|
|
|
|
... self.assert_op(True, (out[0], out[1], out[2], out[3]))
|
|
|
|
|
|
|
|
... return out[2]
|
|
|
|
|
|
|
|
...
|
|
|
|
|
|
|
|
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
|
|
|
|
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
|
|
|
|
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
|
|
|
|
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
|
|
|
|
>>> net = CTCGreedyDecoderNet()
|
|
|
|
>>> ctc_greedy_decoder = ops.CTCGreedyDecoder()
|
|
|
|
>>> output = net(inputs, sequence_length)
|
|
|
|
>>> out1, out2, out3, out4 = ctc_greedy_decoder(inputs, sequence_length)
|
|
|
|
>>> print(output)
|
|
|
|
>>> print(out1, out2, out3, out4)
|
|
|
|
|
|
|
|
[[0 0] [0 1] [1 0]]
|
|
|
|
|
|
|
|
[0 1 0]
|
|
|
|
|
|
|
|
[2 2]
|
|
|
|
|
|
|
|
[[-0.7443749] [0.18251707]]
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@prim_attr_register
|
|
|
|
def __init__(self, merge_repeated=True):
|
|
|
|
def __init__(self, merge_repeated=True):
|
|
|
|
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
|
|
|
|
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, inputs_shape, sequence_length_shape):
|
|
|
|
def check_shape(self, inputs_shape, sequence_length_shape):
|
|
|
|
validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name)
|
|
|
|
validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name)
|
|
|
|
validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name)
|
|
|
|
validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name)
|
|
|
|
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
|
|
|
|
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
|
|
|
@ -6255,7 +6248,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
|
|
|
|
log_probability_shape = [inputs_shape[1], 1]
|
|
|
|
log_probability_shape = [inputs_shape[1], 1]
|
|
|
|
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
|
|
|
|
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
|
|
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, inputs_dtype, sequence_length_dtype):
|
|
|
|
def check_dtype(self, inputs_dtype, sequence_length_dtype):
|
|
|
|
validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name)
|
|
|
|
validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name)
|
|
|
|
validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name)
|
|
|
|
validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name)
|
|
|
|
decoded_type = mstype.tensor_type(mstype.int64)
|
|
|
|
decoded_type = mstype.tensor_type(mstype.int64)
|
|
|
|