|
|
|
@ -684,6 +684,7 @@ class LBeta(Cell):
|
|
|
|
|
self.shape = P.Shape()
|
|
|
|
|
self.dtype = P.DType()
|
|
|
|
|
self.lgamma = LGamma()
|
|
|
|
|
self.const = P.ScalarToTensor()
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y):
|
|
|
|
|
x_dtype = self.dtype(x)
|
|
|
|
@ -714,9 +715,9 @@ class LBeta(Cell):
|
|
|
|
|
log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)
|
|
|
|
|
|
|
|
|
|
# Two large arguments case: y >= x >= 8.
|
|
|
|
|
log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \
|
|
|
|
|
+ log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
|
|
|
|
|
+ (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
|
|
|
|
|
log_beta_two_large = self.const(0.5 * self.log_2pi, x_dtype) - 0.5 * self.log(y_max) \
|
|
|
|
|
+ log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
|
|
|
|
|
+ (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
|
|
|
|
|
|
|
|
|
|
cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
|
|
|
|
|
correction = log_gamma_correction_y - log_gamma_correction_x_y
|
|
|
|
|