|
|
|
@ -3465,7 +3465,7 @@ class SparseApplyFtrl(PrimitiveWithInfer):
|
|
|
|
|
validator.check_value_type("l1", l1, [float], self.name)
|
|
|
|
|
validator.check_value_type("l2", l2, [float], self.name)
|
|
|
|
|
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
|
|
|
|
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
|
|
|
|
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
|
|
|
|
self.l1 = validator.check_number("l1", l1, 0.0, Rel.GE, self.name)
|
|
|
|
|
self.l2 = validator.check_number("l2", l2, 0.0, Rel.GE, self.name)
|
|
|
|
|
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
|
|
|
@ -3656,7 +3656,7 @@ class CTCLoss(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=False,
|
|
|
|
|
def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=True,
|
|
|
|
|
ignore_longer_outputs_than_inputs=False):
|
|
|
|
|
self.init_prim_io_names(inputs=["inputs", "labels_indices", "labels_values", "sequence_length"],
|
|
|
|
|
outputs=["loss", "gradient"])
|
|
|
|
|