diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index cfadca67ed..dc3070701d 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -666,9 +666,11 @@ class IGamma(Cell): domain_error = self.logicalor(self.less(x, 0), self.less(a, 0)) use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) ax = a * self.log(x) - x - self.lgamma(a) - boradcastto = P.BroadcastTo(self.shape(ax)) - a = boradcastto(a) - x = boradcastto(x) + para_shape = self.shape(ax) + boradcastto = P.BroadcastTo(para_shape) + if para_shape != (): + x = boradcastto(x) + y = boradcastto(y) x_is_zero = self.equal(x, 0) if a_dtype == mstype.float64: log_maxfloat = self.log_maxfloat64 @@ -741,9 +743,11 @@ class LBeta(Cell): _check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name) _check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name) x_plus_y = x + y - boradcastto = P.BroadcastTo(self.shape(x_plus_y)) - x = boradcastto(x) - y = boradcastto(y) + para_shape = self.shape(x_plus_y) + boradcastto = P.BroadcastTo(para_shape) + if para_shape != (): + x = boradcastto(x) + y = boradcastto(y) comp_less = self.less(x, y) x_min = self.select(comp_less, x, y) y_max = self.select(comp_less, y, x)