@ -69,6 +69,7 @@ class ScalarAffine(Bijector):
param=param)
self.abs = P.Abs()
self.oneslike = P.OnesLike()
self.log = log_generic
@property
@ -92,7 +93,7 @@ class ScalarAffine(Bijector):
f(x) = a * x + b
"""
x = self._check_value(x, 'value')
return self.scale * x + self.shift
return self.scale * x + self.shift * self.oneslike(x)
def _inverse(self, y):
r"""