|
|
|
@ -1884,6 +1884,7 @@ class SmoothL1Loss(PrimitiveWithInfer):
|
|
|
|
|
validator.check_value_type('beta', beta, [float], self.name)
|
|
|
|
|
validator.check('beta', beta, '', 0, Rel.GT, self.name)
|
|
|
|
|
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
|
|
|
|
|
self.add_prim_attr('sigma', beta)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, prediction, target):
|
|
|
|
|
validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
|
|
|
|
|