!9680 Make constant numbers to tensors to avoid a bug

From: @peixu_ren
Reviewed-by: @zichun_ye,@sunnybeike
Signed-off-by: @sunnybeike
pull/9680/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 37390519cb

@ -684,6 +684,7 @@ class LBeta(Cell):
self.shape = P.Shape()
self.dtype = P.DType()
self.lgamma = LGamma()
self.const = P.ScalarToTensor()
def construct(self, x, y):
x_dtype = self.dtype(x)
@ -714,9 +715,9 @@ class LBeta(Cell):
log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)
# Two large arguments case: y >= x >= 8.
log_beta_two_large = 0.5 * self.log_2pi - 0.5 * self.log(y_max) \
+ log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
+ (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
log_beta_two_large = self.const(0.5 * self.log_2pi, x_dtype) - 0.5 * self.log(y_max) \
+ log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
+ (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
correction = log_gamma_correction_y - log_gamma_correction_x_y

Loading…
Cancel
Save