|
|
|
@ -3410,7 +3410,7 @@ class MirrorPad(PrimitiveWithInfer):
|
|
|
|
|
'value': None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ComputeAccidentalHits(PrimitiveWithInfer):
|
|
|
|
|
class ComputeAccidentalHits(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Compute accidental hits of sampled classes which happen to match target classes.
|
|
|
|
|
|
|
|
|
@ -3455,17 +3455,18 @@ class ComputeAccidentalHits(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'],
|
|
|
|
|
outputs=['indices', 'ids', 'weights'])
|
|
|
|
|
validator.check_value_type("num_true", num_true, [int], self.name)
|
|
|
|
|
validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
|
|
|
|
|
self.num_true = num_true
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, true_classes_shape, sampled_candidates_shape):
|
|
|
|
|
validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name)
|
|
|
|
|
validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name)
|
|
|
|
|
validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name)
|
|
|
|
|
def check_shape(self, true_classes_shape, sampled_candidates_shape):
|
|
|
|
|
validator.check_int(len(true_classes_shape), 2, Rel.EQ, 'dim of true_classes', self.name)
|
|
|
|
|
validator.check_int(len(sampled_candidates_shape), 1, Rel.EQ, 'dim of sampled_candidates', self.name)
|
|
|
|
|
validator.check("true_classes shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
|
|
|
|
|
|
|
|
|
|
indices_len = -1
|
|
|
|
|
return (indices_len,), (indices_len,), (indices_len,)
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, true_classes_type, sampled_candidates_type):
|
|
|
|
|
def check_dtype(self, true_classes_type, sampled_candidates_type):
|
|
|
|
|
validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
|
|
|
|
|
validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name)
|
|
|
|
|
valid_types = (mstype.int32, mstype.int64)
|
|
|
|
@ -6107,13 +6108,13 @@ class CTCLoss(PrimitiveWithInfer):
|
|
|
|
|
>>> ctc_loss = ops.CTCLoss()
|
|
|
|
|
>>> loss, gradient = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
|
|
|
|
|
>>> print(loss)
|
|
|
|
|
[0.69121575 0.5381993 ]
|
|
|
|
|
[0.69121575 0.5381993]
|
|
|
|
|
>>> print(gradient)
|
|
|
|
|
[[[ 0.25831494 0.3623634 -0.62067937]
|
|
|
|
|
[ 0.25187883 0.2921483 -0.5440271 ]]
|
|
|
|
|
[[[0.25831494 0.3623634 -0.62067937]
|
|
|
|
|
[0.25187883 0.2921483 -0.5440271]]
|
|
|
|
|
|
|
|
|
|
[[ 0.43522435 0.24408469 0.07787037 ]
|
|
|
|
|
[ 0.29642645 0.4232373 0.06138104 ]]]
|
|
|
|
|
[[0.43522435 0.24408469 0.07787037]
|
|
|
|
|
[0.29642645 0.4232373 0.06138104]]]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|