|
|
|
@ -666,9 +666,11 @@ class IGamma(Cell):
|
|
|
|
|
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
|
|
|
|
|
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
|
|
|
|
|
ax = a * self.log(x) - x - self.lgamma(a)
|
|
|
|
|
boradcastto = P.BroadcastTo(self.shape(ax))
|
|
|
|
|
a = boradcastto(a)
|
|
|
|
|
para_shape = self.shape(ax)
|
|
|
|
|
boradcastto = P.BroadcastTo(para_shape)
|
|
|
|
|
if para_shape != ():
|
|
|
|
|
x = boradcastto(x)
|
|
|
|
|
y = boradcastto(y)
|
|
|
|
|
x_is_zero = self.equal(x, 0)
|
|
|
|
|
if a_dtype == mstype.float64:
|
|
|
|
|
log_maxfloat = self.log_maxfloat64
|
|
|
|
@ -741,7 +743,9 @@ class LBeta(Cell):
|
|
|
|
|
_check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
|
|
|
|
_check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name)
|
|
|
|
|
x_plus_y = x + y
|
|
|
|
|
boradcastto = P.BroadcastTo(self.shape(x_plus_y))
|
|
|
|
|
para_shape = self.shape(x_plus_y)
|
|
|
|
|
boradcastto = P.BroadcastTo(para_shape)
|
|
|
|
|
if para_shape != ():
|
|
|
|
|
x = boradcastto(x)
|
|
|
|
|
y = boradcastto(y)
|
|
|
|
|
comp_less = self.less(x, y)
|
|
|
|
|